FK-TransUNet++ (GPR-SR) — TransUNet with CUP decoder & f–k spectral consistency for GPR super-resolution

In [3]:
import os
import math
import cv2
import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

# ==========================
# CONFIG
# ==========================
DATASET_DIR = r"C:\Preet\9000_paired_bscans"  # *_l.png and *_h.png paired images
IMAGE_SIZE = (256, 256)
BATCH_SIZE = 8
EPOCHS = 60
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Running on:", DEVICE)

MODEL_PATH = os.path.join(DATASET_DIR, "fk_transunet++.pth")
RESULTS_DIR = os.path.join(DATASET_DIR, "predictions_fk_transunet")
os.makedirs(RESULTS_DIR, exist_ok=True)

# f–k loss weights / options (tune if needed)
FK_WEIGHT = 0.5           # contribution of f–k term to total loss
FK_USE_LOG = True         # use log magnitude for stability
FK_USE_HANN = True        # apply 2D Hann window before FFT to reduce leakage
FK_KY = (0.05, 0.85)      # rectangular band in normalized ky (temporal freq)
FK_KX = (0.00, 0.90)      # rectangular band in normalized kx (spatial wavenumber)

# ==========================
# DATASET
# ==========================
class GPRDataset(Dataset):
    def __init__(self, x_paths, y_paths):
        self.x_paths = x_paths
        self.y_paths = y_paths

    def __len__(self):
        return len(self.x_paths)

    def __getitem__(self, idx):
        x = np.array(Image.open(self.x_paths[idx]).resize(IMAGE_SIZE), dtype=np.float32) / 255.0
        y = np.array(Image.open(self.y_paths[idx]).resize(IMAGE_SIZE), dtype=np.float32) / 255.0
        x = torch.tensor(x).unsqueeze(0)  # (1,H,W)
        y = torch.tensor(y).unsqueeze(0)
        return x, y

def load_data(dataset_dir):
    low_paths, high_paths = [], []
    for file in os.listdir(dataset_dir):
        if file.endswith("_l.png"):
            low_path = os.path.join(dataset_dir, file)
            high_path = os.path.join(dataset_dir, file.replace("_l.png", "_h.png"))
            if os.path.exists(high_path):
                low_paths.append(low_path)
                high_paths.append(high_path)
    return low_paths, high_paths

all_x, all_y = load_data(DATASET_DIR)

# Split 70/15/15
train_x, temp_x, train_y, temp_y = train_test_split(all_x, all_y, test_size=0.30, random_state=42)
val_x, test_x, val_y, test_y = train_test_split(temp_x, temp_y, test_size=0.50, random_state=42)

print(f"Train: {len(train_x)}, Val: {len(val_x)}, Test: {len(test_x)}")

train_loader = DataLoader(GPRDataset(train_x, train_y), batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(GPRDataset(val_x, val_y), batch_size=BATCH_SIZE, shuffle=False)
test_loader  = DataLoader(GPRDataset(test_x, test_y), batch_size=1, shuffle=False)

# ==========================
# MODEL BLOCKS
# ==========================
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.conv(x)

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads=4, ff_dim=512, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + self.drop(attn_out))
        ff_out = self.ff(x)
        x = self.norm2(x + self.drop(ff_out))
        return x

class TransformerBottleneck(nn.Module):
    def __init__(self, in_ch, patch_size=16, num_layers=2, num_heads=4):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = in_ch

        self.flatten = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        self.project = nn.Linear(patch_size * patch_size * in_ch, in_ch)

        self.transformer = nn.Sequential(*[
            TransformerBlock(embed_dim=in_ch, num_heads=num_heads) for _ in range(num_layers)
        ])

        self.reconstruct = nn.Linear(in_ch, patch_size * patch_size * in_ch)
        self.fold = None  # initialized at runtime

    def forward(self, x):
        B, C, H, W = x.shape
        patches = self.flatten(x).transpose(1, 2)  # (B, N, patch_dim)
        patches = self.project(patches)
        patches = self.transformer(patches)
        patches = self.reconstruct(patches).transpose(1, 2)
        if self.fold is None:
            self.fold = nn.Fold(output_size=(H, W),
                                kernel_size=self.patch_size,
                                stride=self.patch_size)
        x_reconstructed = self.fold(patches)
        return x_reconstructed

# ---- NEW: CUP upsampler block ----
class UpCUP(nn.Module):
    """Detail-preserving upsampler: conv -> pixelshuffle(2x) -> conv"""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.Sequential(
            nn.Conv2d(in_ch, out_ch * 4, kernel_size=3, padding=1),
            nn.PixelShuffle(2),                         # (H,W) -> (2H,2W)
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.GELU()
        )
    def forward(self, x):
        return self.up(x)

class TransUNet(nn.Module):
    def __init__(self, in_ch=1, out_ch=1, base_ch=64):
        super().__init__()
        # ================= Encoder =================
        self.enc1 = ConvBlock(in_ch, base_ch)            # 1 -> 64
        self.enc2 = ConvBlock(base_ch, base_ch*2)        # 64 -> 128
        self.enc3 = ConvBlock(base_ch*2, base_ch*4)      # 128 -> 256
        self.pool = nn.MaxPool2d(2)

        # ================= Bottleneck =================
        self.bottleneck = TransformerBottleneck(base_ch*4, patch_size=16, num_layers=2, num_heads=4)

        # ================= Decoder (CUP) =================
        # stage with e3 skip
        self.up3  = UpCUP(base_ch*4, base_ch*2)                         # 256 -> 128 @ H/4
        self.dec3 = ConvBlock(base_ch*2 + base_ch*4, base_ch*2)         # [128 + 256] -> 128
        # stage with e2 skip
        self.up2  = UpCUP(base_ch*2, base_ch)                           # 128 -> 64  @ H/2
        self.dec2 = ConvBlock(base_ch + base_ch*2, base_ch)             # [64 + 128] -> 64
        # NEW: stage with e1 (high-res) skip
        self.up1  = UpCUP(base_ch, base_ch//2)                          # 64 -> 32  @ H
        self.dec1 = ConvBlock(base_ch//2 + base_ch, base_ch//2)         # [32 + 64] -> 32

        self.out_conv = nn.Conv2d(base_ch//2, out_ch, kernel_size=1)    # 32 -> 1

    def forward(self, x):
        # ---------------- Encoder ----------------
        e1 = self.enc1(x)                       # [B,64,H,W]
        e2 = self.enc2(self.pool(e1))           # [B,128,H/2,W/2]
        e3 = self.enc3(self.pool(e2))           # [B,256,H/4,W/4]

        # ---------------- Bottleneck ----------------
        b = self.bottleneck(self.pool(e3))      # [B,256,H/8,W/8]

        # ---------------- Decoder (CUP + all skips) ----------------
        d3 = self.dec3(torch.cat([self.up3(b),  e3], dim=1))   # -> [B,128,H/4,W/4]
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))   # -> [B,64, H/2,W/2]
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))   # -> [B,32, H,  W]
        out = self.out_conv(d1)                                # -> [B,1,  H,  W]
        return out

# ==========================
# f–k LOSS (frequency–wavenumber consistency)
# ==========================
def make_rect_fk_mask(H, W, device, ky=(0.05, 0.85), kx=(0.00, 0.90)):
    """
    Simple rectangular bandpass in normalized frequency space.
    ky, kx are fractions of Nyquist in [0,1].
    """
    yy = torch.linspace(-1, 1, H, device=device).abs().view(H, 1).expand(H, W)
    xx = torch.linspace(-1, 1, W, device=device).abs().view(1, W).expand(H, W)
    m = ((yy >= ky[0]) & (yy <= ky[1]) & (xx >= kx[0]) & (xx <= kx[1])).float()
    return m.unsqueeze(0).unsqueeze(0)  # (1,1,H,W)

def fk_loss(pred, target, bandpass=None, p=1, use_log=True, window=True):
    """
    pred/target: (B,1,H,W) in [0,1]
    Compare spectra in f–k domain (2D FFT). L1 by default.
    """
    assert pred.shape == target.shape
    if window:
        H, W = pred.shape[-2:]
        wy = torch.hann_window(H, device=pred.device).view(1,1,H,1)
        wx = torch.hann_window(W, device=pred.device).view(1,1,1,W)
        win = wy * wx
        pred = pred * win
        target = target * win

    P = torch.fft.fft2(pred, norm='ortho')
    T = torch.fft.fft2(target, norm='ortho')
    Pm = torch.abs(P)
    Tm = torch.abs(T)
    if use_log:
        Pm = torch.log1p(Pm)
        Tm = torch.log1p(Tm)
    diff = (Pm - Tm).abs() if p == 1 else (Pm - Tm).pow(2)
    if bandpass is not None:
        diff = diff * bandpass.to(diff.dtype)
    return diff.mean()

# ==========================
# TRAINING
# ==========================
model = TransUNet().to(DEVICE)
huber = nn.HuberLoss(delta=1.0)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

best_val_loss = float("inf")

for epoch in range(EPOCHS):
    model.train()
    train_loss = 0.0
    for x, y in train_loader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        optimizer.zero_grad()

        preds = model(x)

        # Combine Huber + f–k loss (compute f–k on clamped [0,1] for stability)
        fk_mask = make_rect_fk_mask(y.shape[2], y.shape[3], y.device, ky=FK_KY, kx=FK_KX)
        loss_huber = huber(preds, y)
        loss_fk = fk_loss(preds.clamp(0,1), y, bandpass=fk_mask, p=1, use_log=FK_USE_LOG, window=FK_USE_HANN)
        loss = loss_huber + FK_WEIGHT * loss_fk

        loss.backward()
        optimizer.step()
        train_loss += loss.item() * x.size(0)
    train_loss /= len(train_loader.dataset)

    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            preds = model(x)
            fk_mask = make_rect_fk_mask(y.shape[2], y.shape[3], y.device, ky=FK_KY, kx=FK_KX)
            loss_huber = huber(preds, y)
            loss_fk = fk_loss(preds.clamp(0,1), y, bandpass=fk_mask, p=1, use_log=FK_USE_LOG, window=FK_USE_HANN)
            val_loss += (loss_huber + FK_WEIGHT * loss_fk).item() * x.size(0)
    val_loss /= len(val_loader.dataset)

    print(f"Epoch [{epoch+1}/{EPOCHS}] Train {train_loss:.6f} | Val {val_loss:.6f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), MODEL_PATH)
        print(f"  ✅ Saved Best Model at Epoch {epoch+1}")

# ==========================
# INFERENCE
# ==========================
print("\nRunning inference on test set...")
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

os.makedirs(RESULTS_DIR, exist_ok=True)

for i, (x, y) in enumerate(test_loader):
    x = x.to(DEVICE)
    with torch.no_grad():
        pred = model(x).cpu().squeeze(0).squeeze(0).numpy()
    pred_img = (np.clip(pred, 0.0, 1.0) * 255.0).astype(np.uint8)
    Image.fromarray(pred_img).save(os.path.join(RESULTS_DIR, f"pred_{i+1}.png"))

print(f"Predictions saved in {RESULTS_DIR}")



Running on: cuda
Train: 6300, Val: 1350, Test: 1350
Epoch [1/60] Train 0.003464 | Val 0.000149
  ✅ Saved Best Model at Epoch 1
Epoch [2/60] Train 0.000120 | Val 0.000109
  ✅ Saved Best Model at Epoch 2
Epoch [3/60] Train 0.000107 | Val 0.000101
  ✅ Saved Best Model at Epoch 3
Epoch [4/60] Train 0.000097 | Val 0.000090
  ✅ Saved Best Model at Epoch 4
Epoch [5/60] Train 0.000091 | Val 0.000107
Epoch [6/60] Train 0.000087 | Val 0.000081
  ✅ Saved Best Model at Epoch 6
Epoch [7/60] Train 0.000083 | Val 0.000088
Epoch [8/60] Train 0.000074 | Val 0.000072
  ✅ Saved Best Model at Epoch 8
Epoch [9/60] Train 0.000068 | Val 0.000060
  ✅ Saved Best Model at Epoch 9
Epoch [10/60] Train 0.000064 | Val 0.000058
  ✅ Saved Best Model at Epoch 10
Epoch [11/60] Train 0.000061 | Val 0.000059
Epoch [12/60] Train 0.000060 | Val 0.000058
  ✅ Saved Best Model at Epoch 12
Epoch [13/60] Train 0.000058 | Val 0.000061
Epoch [14/60] Train 0.000057 | Val 0.000054
  ✅ Saved Best Model at Epoch 14
Epoch [15/60] Trai

In [4]:
import os
import cv2
import numpy as np
from skimage.metrics import structural_similarity as ssim
from math import log10
import shutil

# ========================== PATHS ==========================
RESULTS_DIR = r"C:\Preet\9000_paired_bscans\predictions_fk_transunet"  # predictions from TransUNet
GT_TEST_DIR = r"C:\Preet\9000_paired_bscans\ground_truth_test_fk_transunet"

# Copy ground truth test images to a dedicated folder
os.makedirs(GT_TEST_DIR, exist_ok=True)
for f in test_y:  # test_y comes from your train/val/test split
    shutil.copy(f, GT_TEST_DIR)

# ========================== FUNCTIONS ==========================
def calculate_psnr(img1, img2):
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * log10(255.0 / np.sqrt(mse))

def calculate_ssim(img1, img2):
    return ssim(img1, img2, data_range=255)

# ========================== MAIN EVALUATION ==========================
psnr_values, ssim_values = [], []

pred_files = sorted([f for f in os.listdir(RESULTS_DIR) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
gt_files   = sorted([f for f in os.listdir(GT_TEST_DIR) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])

num_pairs = min(len(pred_files), len(gt_files))
if num_pairs == 0:
    print("[Error] No matching image files found in both directories.")
else:
    if len(pred_files) != len(gt_files):
        print(f"[Warning] Different number of images: Predictions={len(pred_files)}, Ground Truth={len(gt_files)}")
        print(f"Evaluating only first {num_pairs} matched pairs.")

    for i in range(num_pairs):
        pred_path = os.path.join(RESULTS_DIR, pred_files[i])
        gt_path   = os.path.join(GT_TEST_DIR, gt_files[i])

        pred_img = cv2.imread(pred_path, cv2.IMREAD_GRAYSCALE)
        gt_img   = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)

        if pred_img is None or gt_img is None:
            print(f"[Error] Could not load: {pred_files[i]} or {gt_files[i]}")
            continue

        if pred_img.shape != gt_img.shape:
            pred_img = cv2.resize(pred_img, (gt_img.shape[1], gt_img.shape[0]))

        psnr_values.append(calculate_psnr(pred_img, gt_img))
        ssim_values.append(calculate_ssim(pred_img, gt_img))

    if psnr_values and ssim_values:
        print(f"\n---- Test Set Evaluation ----")
        print(f"SSIM: avg={np.mean(ssim_values):.4f}, min={np.min(ssim_values):.4f}, max={np.max(ssim_values):.4f}")
        print(f"PSNR: avg={np.mean(psnr_values):.2f} dB, min={np.min(psnr_values):.2f} dB, max={np.max(psnr_values):.2f} dB")



---- Test Set Evaluation ----
SSIM: avg=0.8933, min=0.7108, max=0.9979
PSNR: avg=35.34 dB, min=28.53 dB, max=49.41 dB


In [6]:
import os
import re
import numpy as np
from PIL import Image
import cv2
from math import log10
from skimage.metrics import structural_similarity as ssim

import torch
import torch.nn as nn
import torch.nn.functional as F

# =========================
# CONFIG
# =========================
FREQ_DIRS = [
    r"C:\Preet\validation dataset\png_images_650M_1083M",
    r"C:\Preet\validation dataset\png_images_700M_1167M",
    r"C:\Preet\validation dataset\png_images_800M_1333M",
    r"C:\Preet\validation dataset\png_images_850M_1416M",
    r"C:\Preet\validation dataset\png_images_900M_1500M",
]

MODEL_PATH = r"C:\Preet\9000_paired_bscans\fk_transunet++.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMAGE_SIZE = (256, 256)

# =========================
# MODEL BLOCKS (same as main code)
# =========================
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True)
        )
    def forward(self, x): return self.conv(x)

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads=4, ff_dim=512, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.drop = nn.Dropout(dropout)
    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + self.drop(attn_out))
        ff_out = self.ff(x)
        x = self.norm2(x + self.drop(ff_out))
        return x

class TransformerBottleneck(nn.Module):
    def __init__(self, in_ch, patch_size=16, num_layers=2, num_heads=4):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = in_ch
        self.flatten = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        self.project = nn.Linear(patch_size*patch_size*in_ch, in_ch)
        self.transformer = nn.Sequential(*[
            TransformerBlock(embed_dim=in_ch, num_heads=num_heads) for _ in range(num_layers)
        ])
        self.reconstruct = nn.Linear(in_ch, patch_size*patch_size*in_ch)
        self.fold = None
    def forward(self, x):
        B,C,H,W = x.shape
        patches = self.flatten(x).transpose(1,2)
        patches = self.project(patches)
        patches = self.transformer(patches)
        patches = self.reconstruct(patches).transpose(1,2)
        if self.fold is None:
            self.fold = nn.Fold(output_size=(H,W), kernel_size=self.patch_size, stride=self.patch_size)
        return self.fold(patches)

class UpCUP(nn.Module):
    """Conv -> PixelShuffle(2x) -> Conv"""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.Sequential(
            nn.Conv2d(in_ch, out_ch*4, kernel_size=3, padding=1),
            nn.PixelShuffle(2),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.GELU()
        )
    def forward(self, x): return self.up(x)

class TransUNet(nn.Module):
    def __init__(self, in_ch=1, out_ch=1, base_ch=64):
        super().__init__()
        # Encoder
        self.enc1 = ConvBlock(in_ch, base_ch)
        self.enc2 = ConvBlock(base_ch, base_ch*2)
        self.enc3 = ConvBlock(base_ch*2, base_ch*4)
        self.pool = nn.MaxPool2d(2)
        # Bottleneck
        self.bottleneck = TransformerBottleneck(base_ch*4, patch_size=16, num_layers=2, num_heads=4)
        # Decoder (CUP + skips)
        self.up3  = UpCUP(base_ch*4, base_ch*2)
        self.dec3 = ConvBlock(base_ch*2 + base_ch*4, base_ch*2)
        self.up2  = UpCUP(base_ch*2, base_ch)
        self.dec2 = ConvBlock(base_ch + base_ch*2, base_ch)
        self.up1  = UpCUP(base_ch, base_ch//2)
        self.dec1 = ConvBlock(base_ch//2 + base_ch, base_ch//2)
        self.out_conv = nn.Conv2d(base_ch//2, out_ch, kernel_size=1)
    def forward(self, x):
        e1 = self.enc1(x)                      # [B,64,H,W]
        e2 = self.enc2(self.pool(e1))          # [B,128,H/2,W/2]
        e3 = self.enc3(self.pool(e2))          # [B,256,H/4,W/4]
        b  = self.bottleneck(self.pool(e3))    # [B,256,H/8,W/8]
        d3 = self.dec3(torch.cat([self.up3(b),  e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
        return self.out_conv(d1)

# =========================
# HELPERS
# =========================
def calculate_psnr(img1, img2):
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0: return float('inf')
    return 20 * log10(255.0 / np.sqrt(mse))

def calculate_ssim(img1, img2):
    return ssim(img1, img2, data_range=255)

def numerical_sort(files):
    return sorted(files, key=lambda f: int(re.search(r'\d+', f).group()))

# =========================
# LOAD MODEL
# =========================
model = TransUNet().to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

# =========================
# INFERENCE LOOP
# =========================
for freq_dir in FREQ_DIRS:
    print(f"\n---- Evaluating {os.path.basename(freq_dir)} ----")
    pred_dir = os.path.join(freq_dir, "predictions_fk_transunet")
    os.makedirs(pred_dir, exist_ok=True)

    low_files = numerical_sort([f for f in os.listdir(freq_dir) if f.endswith("_l.png")])
    high_files = numerical_sort([f for f in os.listdir(freq_dir) if f.endswith("_h.png")])

    psnr_values, ssim_values = [], []

    for i in range(len(low_files)):
        low_path = os.path.join(freq_dir, low_files[i])
        high_path = os.path.join(freq_dir, high_files[i])

        # Load LR
        lr_img = np.array(Image.open(low_path).resize(IMAGE_SIZE), dtype=np.float32) / 255.0
        lr_tensor = torch.tensor(lr_img).unsqueeze(0).unsqueeze(0).to(DEVICE)

        # Predict
        with torch.no_grad():
            pred = model(lr_tensor).cpu().squeeze().numpy()

        # Scale to 0-255
        pred_img = (np.clip(pred, 0.0, 1.0) * 255.0).astype(np.uint8)

        # Save prediction
        pred_path = os.path.join(pred_dir, f"pred_{i+1}.png")
        Image.fromarray(pred_img).save(pred_path)

        # Load GT
        gt_img = cv2.imread(high_path, cv2.IMREAD_GRAYSCALE)
        if pred_img.shape != gt_img.shape:
            pred_img = cv2.resize(pred_img, (gt_img.shape[1], gt_img.shape[0]))

        # Metrics
        psnr_values.append(calculate_psnr(pred_img, gt_img))
        ssim_values.append(calculate_ssim(pred_img, gt_img))

    if psnr_values and ssim_values:
        print(f"SSIM: avg={np.mean(ssim_values):.4f}, min={np.min(ssim_values):.4f}, max={np.max(ssim_values):.4f}")
        print(f"PSNR: avg={np.mean(psnr_values):.2f} dB, min={np.min(psnr_values):.2f} dB, max={np.max(psnr_values):.2f} dB")
    else:
        print("[Error] No valid pairs processed.")



---- Evaluating png_images_650M_1083M ----
SSIM: avg=0.9504, min=0.9216, max=0.9624
PSNR: avg=35.95 dB, min=31.54 dB, max=37.29 dB

---- Evaluating png_images_700M_1167M ----
SSIM: avg=0.9852, min=0.7560, max=0.9900
PSNR: avg=38.55 dB, min=29.26 dB, max=39.93 dB

---- Evaluating png_images_800M_1333M ----
SSIM: avg=0.9902, min=0.9735, max=0.9938
PSNR: avg=39.53 dB, min=34.42 dB, max=40.78 dB

---- Evaluating png_images_850M_1416M ----
SSIM: avg=0.9580, min=0.9317, max=0.9676
PSNR: avg=34.91 dB, min=31.92 dB, max=36.29 dB

---- Evaluating png_images_900M_1500M ----
SSIM: avg=0.9242, min=0.8702, max=0.9424
PSNR: avg=32.61 dB, min=29.65 dB, max=34.51 dB


In [8]:
import os
import re
import numpy as np
from PIL import Image
import cv2
from math import log10
from skimage.metrics import structural_similarity as ssim

import torch
import torch.nn as nn
import torch.nn.functional as F

# =========================
# CONFIG
# =========================
DATASET_DIR = r"C:\Preet\validation dataset\png_images_650M_1083M_1800M"
MODEL_PATH = r"C:\Preet\9000_paired_bscans\fk_transunet++.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMAGE_SIZE = (256, 256)

# =========================
# MODEL BLOCKS (same as main code)
# =========================
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True)
        )
    def forward(self, x): return self.conv(x)

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads=4, ff_dim=512, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.drop = nn.Dropout(dropout)
    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + self.drop(attn_out))
        ff_out = self.ff(x)
        x = self.norm2(x + self.drop(ff_out))
        return x

class TransformerBottleneck(nn.Module):
    def __init__(self, in_ch, patch_size=16, num_layers=2, num_heads=4):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = in_ch
        self.flatten = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        self.project = nn.Linear(patch_size*patch_size*in_ch, in_ch)
        self.transformer = nn.Sequential(*[
            TransformerBlock(embed_dim=in_ch, num_heads=num_heads) for _ in range(num_layers)
        ])
        self.reconstruct = nn.Linear(in_ch, patch_size*patch_size*in_ch)
        self.fold = None
    def forward(self, x):
        B,C,H,W = x.shape
        patches = self.flatten(x).transpose(1,2)
        patches = self.project(patches)
        patches = self.transformer(patches)
        patches = self.reconstruct(patches).transpose(1,2)
        if self.fold is None:
            self.fold = nn.Fold(output_size=(H,W), kernel_size=self.patch_size, stride=self.patch_size)
        return self.fold(patches)

class UpCUP(nn.Module):
    """Conv -> PixelShuffle(2x) -> Conv"""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.Sequential(
            nn.Conv2d(in_ch, out_ch*4, kernel_size=3, padding=1),
            nn.PixelShuffle(2),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.GELU()
        )
    def forward(self, x): return self.up(x)

class TransUNet(nn.Module):
    def __init__(self, in_ch=1, out_ch=1, base_ch=64):
        super().__init__()
        # Encoder
        self.enc1 = ConvBlock(in_ch, base_ch)
        self.enc2 = ConvBlock(base_ch, base_ch*2)
        self.enc3 = ConvBlock(base_ch*2, base_ch*4)
        self.pool = nn.MaxPool2d(2)
        # Bottleneck
        self.bottleneck = TransformerBottleneck(base_ch*4, patch_size=16, num_layers=2, num_heads=4)
        # Decoder (CUP + skips)
        self.up3  = UpCUP(base_ch*4, base_ch*2)
        self.dec3 = ConvBlock(base_ch*2 + base_ch*4, base_ch*2)
        self.up2  = UpCUP(base_ch*2, base_ch)
        self.dec2 = ConvBlock(base_ch + base_ch*2, base_ch)
        self.up1  = UpCUP(base_ch, base_ch//2)
        self.dec1 = ConvBlock(base_ch//2 + base_ch, base_ch//2)
        self.out_conv = nn.Conv2d(base_ch//2, out_ch, kernel_size=1)
    def forward(self, x):
        e1 = self.enc1(x)                      # [B,64,H,W]
        e2 = self.enc2(self.pool(e1))          # [B,128,H/2,W/2]
        e3 = self.enc3(self.pool(e2))          # [B,256,H/4,W/4]
        b  = self.bottleneck(self.pool(e3))    # [B,256,H/8,W/8]
        d3 = self.dec3(torch.cat([self.up3(b),  e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
        return self.out_conv(d1)


# =========================
# HELPERS
# =========================
def calculate_psnr(img1, img2):
    img1, img2 = img1.astype(np.float32), img2.astype(np.float32)
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0: return float('inf')
    return 20 * log10(255.0 / np.sqrt(mse))

def calculate_ssim(img1, img2):
    return ssim(img1, img2, data_range=255)

def numerical_sort(files):
    return sorted(files, key=lambda f: int(re.search(r'\d+', f).group()))

def run_model(img, model):
    """Runs inference with proper scaling and output normalization"""
    tensor = torch.tensor(img, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        pred = model(tensor).cpu().numpy()
    pred = np.squeeze(pred)
    # scale back to [0,255]
    pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)
    pred = (pred * 255.0).astype(np.uint8)
    return pred

# =========================
# LOAD MODEL
# =========================
model = TransUNet().to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

# =========================
# INFERENCE 2-STAGE
# =========================
files_650 = numerical_sort([f for f in os.listdir(DATASET_DIR) if f.endswith("650_bscan.png")])
files_1083 = numerical_sort([f for f in os.listdir(DATASET_DIR) if f.endswith("1083_bscan.png")])
files_1800 = numerical_sort([f for f in os.listdir(DATASET_DIR) if f.endswith("1800_bscan.png")])

psnr_stage1, ssim_stage1 = [], []
psnr_stage2, ssim_stage2 = [], []

pred_dir1 = os.path.join(DATASET_DIR, "predictions_stage1_fk_transunet")
pred_dir2 = os.path.join(DATASET_DIR, "predictions_stage2_fk_transunet")
os.makedirs(pred_dir1, exist_ok=True)
os.makedirs(pred_dir2, exist_ok=True)

for i in range(len(files_650)):
    # ---- Stage 1: 650 -> pred -> compare with 1083 ----
    lr_img = np.array(Image.open(os.path.join(DATASET_DIR, files_650[i])).resize(IMAGE_SIZE), dtype=np.float32) / 255.0
    gt_1083 = cv2.imread(os.path.join(DATASET_DIR, files_1083[i]), cv2.IMREAD_GRAYSCALE)

    pred_1083 = run_model(lr_img, model)
    Image.fromarray(pred_1083).save(os.path.join(pred_dir1, f"pred1_{i+1}.png"))

    if pred_1083.shape != gt_1083.shape:
        pred_1083 = cv2.resize(pred_1083, (gt_1083.shape[1], gt_1083.shape[0]))

    psnr_stage1.append(calculate_psnr(pred_1083, gt_1083))
    ssim_stage1.append(calculate_ssim(pred_1083, gt_1083))

    # ---- Stage 2: pred_1083 -> pred -> compare with 1800 ----
    gt_1800 = cv2.imread(os.path.join(DATASET_DIR, files_1800[i]), cv2.IMREAD_GRAYSCALE)
    pred_1083_resized = cv2.resize(pred_1083, IMAGE_SIZE).astype(np.float32) / 255.0
    pred_1800 = run_model(pred_1083_resized, model)
    Image.fromarray(pred_1800).save(os.path.join(pred_dir2, f"pred2_{i+1}.png"))

    if pred_1800.shape != gt_1800.shape:
        pred_1800 = cv2.resize(pred_1800, (gt_1800.shape[1], gt_1800.shape[0]))

    psnr_stage2.append(calculate_psnr(pred_1800, gt_1800))
    ssim_stage2.append(calculate_ssim(pred_1800, gt_1800))

# =========================
# RESULTS
# =========================
print("\n---- Stage 1: 650 → pred → compare with 1083 ----")
print(f"SSIM: avg={np.mean(ssim_stage1):.4f}, min={np.min(ssim_stage1):.4f}, max={np.max(ssim_stage1):.4f}")
print(f"PSNR: avg={np.mean(psnr_stage1):.2f} dB, min={np.min(psnr_stage1):.2f} dB, max={np.max(psnr_stage1):.2f} dB")

print("\n---- Stage 2: pred_1083 → pred → compare with 1800 ----")
print(f"SSIM: avg={np.mean(ssim_stage2):.4f}, min={np.min(ssim_stage2):.4f}, max={np.max(ssim_stage2):.4f}")
print(f"PSNR: avg={np.mean(psnr_stage2):.2f} dB, min={np.min(psnr_stage2):.2f} dB, max={np.max(psnr_stage2):.2f} dB")



---- Stage 1: 650 → pred → compare with 1083 ----
SSIM: avg=0.9478, min=0.9217, max=0.9610
PSNR: avg=27.36 dB, min=26.05 dB, max=27.78 dB

---- Stage 2: pred_1083 → pred → compare with 1800 ----
SSIM: avg=0.9347, min=0.8929, max=0.9553
PSNR: avg=27.32 dB, min=23.13 dB, max=28.62 dB


In [10]:
import os
import re
import numpy as np
from PIL import Image
import cv2
from math import log10
from skimage.metrics import structural_similarity as ssim

import torch
import torch.nn as nn
import torch.nn.functional as F

# =========================
# CONFIG
# =========================
DATASET_DIR = r"C:\Preet\validation dataset\png_images_900M_1500M_2500M"
MODEL_PATH = r"C:\Preet\9000_paired_bscans\fk_transunet++.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMAGE_SIZE = (256, 256)

# =========================
# MODEL BLOCKS (same as main code)
# =========================
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True)
        )
    def forward(self, x): return self.conv(x)

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads=4, ff_dim=512, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.drop = nn.Dropout(dropout)
    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + self.drop(attn_out))
        ff_out = self.ff(x)
        x = self.norm2(x + self.drop(ff_out))
        return x

class TransformerBottleneck(nn.Module):
    def __init__(self, in_ch, patch_size=16, num_layers=2, num_heads=4):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = in_ch
        self.flatten = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        self.project = nn.Linear(patch_size*patch_size*in_ch, in_ch)
        self.transformer = nn.Sequential(*[
            TransformerBlock(embed_dim=in_ch, num_heads=num_heads) for _ in range(num_layers)
        ])
        self.reconstruct = nn.Linear(in_ch, patch_size*patch_size*in_ch)
        self.fold = None
    def forward(self, x):
        B,C,H,W = x.shape
        patches = self.flatten(x).transpose(1,2)
        patches = self.project(patches)
        patches = self.transformer(patches)
        patches = self.reconstruct(patches).transpose(1,2)
        if self.fold is None:
            self.fold = nn.Fold(output_size=(H,W), kernel_size=self.patch_size, stride=self.patch_size)
        return self.fold(patches)

class UpCUP(nn.Module):
    """Conv -> PixelShuffle(2x) -> Conv"""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.Sequential(
            nn.Conv2d(in_ch, out_ch*4, kernel_size=3, padding=1),
            nn.PixelShuffle(2),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.GELU()
        )
    def forward(self, x): return self.up(x)

class TransUNet(nn.Module):
    def __init__(self, in_ch=1, out_ch=1, base_ch=64):
        super().__init__()
        # Encoder
        self.enc1 = ConvBlock(in_ch, base_ch)
        self.enc2 = ConvBlock(base_ch, base_ch*2)
        self.enc3 = ConvBlock(base_ch*2, base_ch*4)
        self.pool = nn.MaxPool2d(2)
        # Bottleneck
        self.bottleneck = TransformerBottleneck(base_ch*4, patch_size=16, num_layers=2, num_heads=4)
        # Decoder (CUP + skips)
        self.up3  = UpCUP(base_ch*4, base_ch*2)
        self.dec3 = ConvBlock(base_ch*2 + base_ch*4, base_ch*2)
        self.up2  = UpCUP(base_ch*2, base_ch)
        self.dec2 = ConvBlock(base_ch + base_ch*2, base_ch)
        self.up1  = UpCUP(base_ch, base_ch//2)
        self.dec1 = ConvBlock(base_ch//2 + base_ch, base_ch//2)
        self.out_conv = nn.Conv2d(base_ch//2, out_ch, kernel_size=1)
    def forward(self, x):
        e1 = self.enc1(x)                      # [B,64,H,W]
        e2 = self.enc2(self.pool(e1))          # [B,128,H/2,W/2]
        e3 = self.enc3(self.pool(e2))          # [B,256,H/4,W/4]
        b  = self.bottleneck(self.pool(e3))    # [B,256,H/8,W/8]
        d3 = self.dec3(torch.cat([self.up3(b),  e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
        return self.out_conv(d1)


# =========================
# HELPERS
# =========================
def calculate_psnr(img1, img2):
    img1, img2 = img1.astype(np.float32), img2.astype(np.float32)
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0: return float('inf')
    return 20 * log10(255.0 / np.sqrt(mse))

def calculate_ssim(img1, img2):
    return ssim(img1, img2, data_range=255)

def numerical_sort(files):
    return sorted(files, key=lambda f: int(re.search(r'\d+', f).group()))

def run_model(img, model):
    """Runs inference with proper scaling and output normalization"""
    tensor = torch.tensor(img, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        pred = model(tensor).cpu().numpy()
    pred = np.squeeze(pred)
    # scale back to [0,255]
    pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)
    pred = (pred * 255.0).astype(np.uint8)
    return pred

# =========================
# LOAD MODEL
# =========================
model = TransUNet().to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

# =========================
# INFERENCE 2-STAGE
# =========================
files_650 = numerical_sort([f for f in os.listdir(DATASET_DIR) if f.endswith("900_bscan.png")])
files_1083 = numerical_sort([f for f in os.listdir(DATASET_DIR) if f.endswith("1500_bscan.png")])
files_1800 = numerical_sort([f for f in os.listdir(DATASET_DIR) if f.endswith("2500_bscan.png")])

psnr_stage1, ssim_stage1 = [], []
psnr_stage2, ssim_stage2 = [], []

pred_dir1 = os.path.join(DATASET_DIR, "predictions_stage1_fk_transunet")
pred_dir2 = os.path.join(DATASET_DIR, "predictions_stage2_fk_transunet")
os.makedirs(pred_dir1, exist_ok=True)
os.makedirs(pred_dir2, exist_ok=True)

for i in range(len(files_650)):
    # ---- Stage 1: 650 -> pred -> compare with 1083 ----
    lr_img = np.array(Image.open(os.path.join(DATASET_DIR, files_650[i])).resize(IMAGE_SIZE), dtype=np.float32) / 255.0
    gt_1083 = cv2.imread(os.path.join(DATASET_DIR, files_1083[i]), cv2.IMREAD_GRAYSCALE)

    pred_1083 = run_model(lr_img, model)
    Image.fromarray(pred_1083).save(os.path.join(pred_dir1, f"pred1_{i+1}.png"))

    if pred_1083.shape != gt_1083.shape:
        pred_1083 = cv2.resize(pred_1083, (gt_1083.shape[1], gt_1083.shape[0]))

    psnr_stage1.append(calculate_psnr(pred_1083, gt_1083))
    ssim_stage1.append(calculate_ssim(pred_1083, gt_1083))

    # ---- Stage 2: pred_1083 -> pred -> compare with 1800 ----
    gt_1800 = cv2.imread(os.path.join(DATASET_DIR, files_1800[i]), cv2.IMREAD_GRAYSCALE)
    pred_1083_resized = cv2.resize(pred_1083, IMAGE_SIZE).astype(np.float32) / 255.0
    pred_1800 = run_model(pred_1083_resized, model)
    Image.fromarray(pred_1800).save(os.path.join(pred_dir2, f"pred2_{i+1}.png"))

    if pred_1800.shape != gt_1800.shape:
        pred_1800 = cv2.resize(pred_1800, (gt_1800.shape[1], gt_1800.shape[0]))

    psnr_stage2.append(calculate_psnr(pred_1800, gt_1800))
    ssim_stage2.append(calculate_ssim(pred_1800, gt_1800))

# =========================
# RESULTS
# =========================
print("\n---- Stage 1: 900 → pred → compare with 1500 ----")
print(f"SSIM: avg={np.mean(ssim_stage1):.4f}, min={np.min(ssim_stage1):.4f}, max={np.max(ssim_stage1):.4f}")
print(f"PSNR: avg={np.mean(psnr_stage1):.2f} dB, min={np.min(psnr_stage1):.2f} dB, max={np.max(psnr_stage1):.2f} dB")

print("\n---- Stage 2: pred_1500 → pred → compare with 2500 ----")
print(f"SSIM: avg={np.mean(ssim_stage2):.4f}, min={np.min(ssim_stage2):.4f}, max={np.max(ssim_stage2):.4f}")
print(f"PSNR: avg={np.mean(psnr_stage2):.2f} dB, min={np.min(psnr_stage2):.2f} dB, max={np.max(psnr_stage2):.2f} dB")


---- Stage 1: 900 → pred → compare with 1500 ----
SSIM: avg=0.7180, min=0.5728, max=0.8518
PSNR: avg=17.10 dB, min=15.76 dB, max=17.82 dB

---- Stage 2: pred_1500 → pred → compare with 2500 ----
SSIM: avg=0.7242, min=0.6003, max=0.8451
PSNR: avg=15.46 dB, min=14.46 dB, max=15.99 dB


------

----------

------------

----------------

4000 paired dataset

In [1]:
import os
import math
import cv2
import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

# ==========================
# CONFIG
# ==========================
DATASET_DIR = r"C:\Preet\4000_paired_bscans"  # *_l.png and *_h.png paired images
IMAGE_SIZE = (256, 256)
BATCH_SIZE = 8
EPOCHS = 60
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Running on:", DEVICE)

MODEL_PATH = os.path.join(DATASET_DIR, "fk_transunet++.pth")
RESULTS_DIR = os.path.join(DATASET_DIR, "predictions_fk_transunet")
os.makedirs(RESULTS_DIR, exist_ok=True)

# f–k loss weights / options (tune if needed)
FK_WEIGHT = 0.5           # contribution of f–k term to total loss
FK_USE_LOG = True         # use log magnitude for stability
FK_USE_HANN = True        # apply 2D Hann window before FFT to reduce leakage
FK_KY = (0.05, 0.85)      # rectangular band in normalized ky (temporal freq)
FK_KX = (0.00, 0.90)      # rectangular band in normalized kx (spatial wavenumber)

# ==========================
# DATASET
# ==========================
class GPRDataset(Dataset):
    def __init__(self, x_paths, y_paths):
        self.x_paths = x_paths
        self.y_paths = y_paths

    def __len__(self):
        return len(self.x_paths)

    def __getitem__(self, idx):
        x = np.array(Image.open(self.x_paths[idx]).resize(IMAGE_SIZE), dtype=np.float32) / 255.0
        y = np.array(Image.open(self.y_paths[idx]).resize(IMAGE_SIZE), dtype=np.float32) / 255.0
        x = torch.tensor(x).unsqueeze(0)  # (1,H,W)
        y = torch.tensor(y).unsqueeze(0)
        return x, y

def load_data(dataset_dir):
    low_paths, high_paths = [], []
    for file in os.listdir(dataset_dir):
        if file.endswith("_l.png"):
            low_path = os.path.join(dataset_dir, file)
            high_path = os.path.join(dataset_dir, file.replace("_l.png", "_h.png"))
            if os.path.exists(high_path):
                low_paths.append(low_path)
                high_paths.append(high_path)
    return low_paths, high_paths

all_x, all_y = load_data(DATASET_DIR)

# Split 70/15/15
train_x, temp_x, train_y, temp_y = train_test_split(all_x, all_y, test_size=0.30, random_state=42)
val_x, test_x, val_y, test_y = train_test_split(temp_x, temp_y, test_size=0.50, random_state=42)

print(f"Train: {len(train_x)}, Val: {len(val_x)}, Test: {len(test_x)}")

train_loader = DataLoader(GPRDataset(train_x, train_y), batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(GPRDataset(val_x, val_y), batch_size=BATCH_SIZE, shuffle=False)
test_loader  = DataLoader(GPRDataset(test_x, test_y), batch_size=1, shuffle=False)

# ==========================
# MODEL BLOCKS
# ==========================
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.conv(x)

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads=4, ff_dim=512, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + self.drop(attn_out))
        ff_out = self.ff(x)
        x = self.norm2(x + self.drop(ff_out))
        return x

class TransformerBottleneck(nn.Module):
    def __init__(self, in_ch, patch_size=16, num_layers=2, num_heads=4):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = in_ch

        self.flatten = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        self.project = nn.Linear(patch_size * patch_size * in_ch, in_ch)

        self.transformer = nn.Sequential(*[
            TransformerBlock(embed_dim=in_ch, num_heads=num_heads) for _ in range(num_layers)
        ])

        self.reconstruct = nn.Linear(in_ch, patch_size * patch_size * in_ch)
        self.fold = None  # initialized at runtime

    def forward(self, x):
        B, C, H, W = x.shape
        patches = self.flatten(x).transpose(1, 2)  # (B, N, patch_dim)
        patches = self.project(patches)
        patches = self.transformer(patches)
        patches = self.reconstruct(patches).transpose(1, 2)
        if self.fold is None:
            self.fold = nn.Fold(output_size=(H, W),
                                kernel_size=self.patch_size,
                                stride=self.patch_size)
        x_reconstructed = self.fold(patches)
        return x_reconstructed

# ---- NEW: CUP upsampler block ----
class UpCUP(nn.Module):
    """Detail-preserving upsampler: conv -> pixelshuffle(2x) -> conv"""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.Sequential(
            nn.Conv2d(in_ch, out_ch * 4, kernel_size=3, padding=1),
            nn.PixelShuffle(2),                         # (H,W) -> (2H,2W)
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.GELU()
        )
    def forward(self, x):
        return self.up(x)

class TransUNet(nn.Module):
    def __init__(self, in_ch=1, out_ch=1, base_ch=64):
        super().__init__()
        # ================= Encoder =================
        self.enc1 = ConvBlock(in_ch, base_ch)            # 1 -> 64
        self.enc2 = ConvBlock(base_ch, base_ch*2)        # 64 -> 128
        self.enc3 = ConvBlock(base_ch*2, base_ch*4)      # 128 -> 256
        self.pool = nn.MaxPool2d(2)

        # ================= Bottleneck =================
        self.bottleneck = TransformerBottleneck(base_ch*4, patch_size=16, num_layers=2, num_heads=4)

        # ================= Decoder (CUP) =================
        # stage with e3 skip
        self.up3  = UpCUP(base_ch*4, base_ch*2)                         # 256 -> 128 @ H/4
        self.dec3 = ConvBlock(base_ch*2 + base_ch*4, base_ch*2)         # [128 + 256] -> 128
        # stage with e2 skip
        self.up2  = UpCUP(base_ch*2, base_ch)                           # 128 -> 64  @ H/2
        self.dec2 = ConvBlock(base_ch + base_ch*2, base_ch)             # [64 + 128] -> 64
        # NEW: stage with e1 (high-res) skip
        self.up1  = UpCUP(base_ch, base_ch//2)                          # 64 -> 32  @ H
        self.dec1 = ConvBlock(base_ch//2 + base_ch, base_ch//2)         # [32 + 64] -> 32

        self.out_conv = nn.Conv2d(base_ch//2, out_ch, kernel_size=1)    # 32 -> 1

    def forward(self, x):
        # ---------------- Encoder ----------------
        e1 = self.enc1(x)                       # [B,64,H,W]
        e2 = self.enc2(self.pool(e1))           # [B,128,H/2,W/2]
        e3 = self.enc3(self.pool(e2))           # [B,256,H/4,W/4]

        # ---------------- Bottleneck ----------------
        b = self.bottleneck(self.pool(e3))      # [B,256,H/8,W/8]

        # ---------------- Decoder (CUP + all skips) ----------------
        d3 = self.dec3(torch.cat([self.up3(b),  e3], dim=1))   # -> [B,128,H/4,W/4]
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))   # -> [B,64, H/2,W/2]
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))   # -> [B,32, H,  W]
        out = self.out_conv(d1)                                # -> [B,1,  H,  W]
        return out

# ==========================
# f–k LOSS (frequency–wavenumber consistency)
# ==========================
def make_rect_fk_mask(H, W, device, ky=(0.05, 0.85), kx=(0.00, 0.90)):
    """
    Simple rectangular bandpass in normalized frequency space.
    ky, kx are fractions of Nyquist in [0,1].
    """
    yy = torch.linspace(-1, 1, H, device=device).abs().view(H, 1).expand(H, W)
    xx = torch.linspace(-1, 1, W, device=device).abs().view(1, W).expand(H, W)
    m = ((yy >= ky[0]) & (yy <= ky[1]) & (xx >= kx[0]) & (xx <= kx[1])).float()
    return m.unsqueeze(0).unsqueeze(0)  # (1,1,H,W)

def fk_loss(pred, target, bandpass=None, p=1, use_log=True, window=True):
    """
    pred/target: (B,1,H,W) in [0,1]
    Compare spectra in f–k domain (2D FFT). L1 by default.
    """
    assert pred.shape == target.shape
    if window:
        H, W = pred.shape[-2:]
        wy = torch.hann_window(H, device=pred.device).view(1,1,H,1)
        wx = torch.hann_window(W, device=pred.device).view(1,1,1,W)
        win = wy * wx
        pred = pred * win
        target = target * win

    P = torch.fft.fft2(pred, norm='ortho')
    T = torch.fft.fft2(target, norm='ortho')
    Pm = torch.abs(P)
    Tm = torch.abs(T)
    if use_log:
        Pm = torch.log1p(Pm)
        Tm = torch.log1p(Tm)
    diff = (Pm - Tm).abs() if p == 1 else (Pm - Tm).pow(2)
    if bandpass is not None:
        diff = diff * bandpass.to(diff.dtype)
    return diff.mean()

# ==========================
# TRAINING
# ==========================
model = TransUNet().to(DEVICE)
huber = nn.HuberLoss(delta=1.0)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

best_val_loss = float("inf")

for epoch in range(EPOCHS):
    model.train()
    train_loss = 0.0
    for x, y in train_loader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        optimizer.zero_grad()

        preds = model(x)

        # Combine Huber + f–k loss (compute f–k on clamped [0,1] for stability)
        fk_mask = make_rect_fk_mask(y.shape[2], y.shape[3], y.device, ky=FK_KY, kx=FK_KX)
        loss_huber = huber(preds, y)
        loss_fk = fk_loss(preds.clamp(0,1), y, bandpass=fk_mask, p=1, use_log=FK_USE_LOG, window=FK_USE_HANN)
        loss = loss_huber + FK_WEIGHT * loss_fk

        loss.backward()
        optimizer.step()
        train_loss += loss.item() * x.size(0)
    train_loss /= len(train_loader.dataset)

    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            preds = model(x)
            fk_mask = make_rect_fk_mask(y.shape[2], y.shape[3], y.device, ky=FK_KY, kx=FK_KX)
            loss_huber = huber(preds, y)
            loss_fk = fk_loss(preds.clamp(0,1), y, bandpass=fk_mask, p=1, use_log=FK_USE_LOG, window=FK_USE_HANN)
            val_loss += (loss_huber + FK_WEIGHT * loss_fk).item() * x.size(0)
    val_loss /= len(val_loader.dataset)

    print(f"Epoch [{epoch+1}/{EPOCHS}] Train {train_loss:.6f} | Val {val_loss:.6f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), MODEL_PATH)
        print(f"  ✅ Saved Best Model at Epoch {epoch+1}")

# ==========================
# INFERENCE
# ==========================
print("\nRunning inference on test set...")
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

os.makedirs(RESULTS_DIR, exist_ok=True)

for i, (x, y) in enumerate(test_loader):
    x = x.to(DEVICE)
    with torch.no_grad():
        pred = model(x).cpu().squeeze(0).squeeze(0).numpy()
    pred_img = (np.clip(pred, 0.0, 1.0) * 255.0).astype(np.uint8)
    Image.fromarray(pred_img).save(os.path.join(RESULTS_DIR, f"pred_{i+1}.png"))

print(f"Predictions saved in {RESULTS_DIR}")



Running on: cuda
Train: 2800, Val: 600, Test: 600
Epoch [1/60] Train 0.005991 | Val 0.000253
  ✅ Saved Best Model at Epoch 1
Epoch [2/60] Train 0.000200 | Val 0.000150
  ✅ Saved Best Model at Epoch 2
Epoch [3/60] Train 0.000142 | Val 0.000131
  ✅ Saved Best Model at Epoch 3
Epoch [4/60] Train 0.000121 | Val 0.000116
  ✅ Saved Best Model at Epoch 4
Epoch [5/60] Train 0.000113 | Val 0.000107
  ✅ Saved Best Model at Epoch 5
Epoch [6/60] Train 0.000105 | Val 0.000102
  ✅ Saved Best Model at Epoch 6
Epoch [7/60] Train 0.000102 | Val 0.000097
  ✅ Saved Best Model at Epoch 7
Epoch [8/60] Train 0.000097 | Val 0.000118
Epoch [9/60] Train 0.000093 | Val 0.000093
  ✅ Saved Best Model at Epoch 9
Epoch [10/60] Train 0.000092 | Val 0.000089
  ✅ Saved Best Model at Epoch 10
Epoch [11/60] Train 0.000089 | Val 0.000088
  ✅ Saved Best Model at Epoch 11
Epoch [12/60] Train 0.000088 | Val 0.000099
Epoch [13/60] Train 0.000082 | Val 0.000085
  ✅ Saved Best Model at Epoch 13
Epoch [14/60] Train 0.000080 | V

In [2]:
import os
import cv2
import numpy as np
from skimage.metrics import structural_similarity as ssim
from math import log10
import shutil

# ========================== PATHS ==========================
RESULTS_DIR = r"C:\Preet\4000_paired_bscans\predictions_fk_transunet"  # predictions from TransUNet
GT_TEST_DIR = r"C:\Preet\4000_paired_bscans\ground_truth_test_fk_transunet"

# Copy ground truth test images to a dedicated folder
os.makedirs(GT_TEST_DIR, exist_ok=True)
for f in test_y:  # test_y comes from your train/val/test split
    shutil.copy(f, GT_TEST_DIR)

# ========================== FUNCTIONS ==========================
def calculate_psnr(img1, img2):
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * log10(255.0 / np.sqrt(mse))

def calculate_ssim(img1, img2):
    return ssim(img1, img2, data_range=255)

# ========================== MAIN EVALUATION ==========================
psnr_values, ssim_values = [], []

pred_files = sorted([f for f in os.listdir(RESULTS_DIR) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
gt_files   = sorted([f for f in os.listdir(GT_TEST_DIR) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])

num_pairs = min(len(pred_files), len(gt_files))
if num_pairs == 0:
    print("[Error] No matching image files found in both directories.")
else:
    if len(pred_files) != len(gt_files):
        print(f"[Warning] Different number of images: Predictions={len(pred_files)}, Ground Truth={len(gt_files)}")
        print(f"Evaluating only first {num_pairs} matched pairs.")

    for i in range(num_pairs):
        pred_path = os.path.join(RESULTS_DIR, pred_files[i])
        gt_path   = os.path.join(GT_TEST_DIR, gt_files[i])

        pred_img = cv2.imread(pred_path, cv2.IMREAD_GRAYSCALE)
        gt_img   = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)

        if pred_img is None or gt_img is None:
            print(f"[Error] Could not load: {pred_files[i]} or {gt_files[i]}")
            continue

        if pred_img.shape != gt_img.shape:
            pred_img = cv2.resize(pred_img, (gt_img.shape[1], gt_img.shape[0]))

        psnr_values.append(calculate_psnr(pred_img, gt_img))
        ssim_values.append(calculate_ssim(pred_img, gt_img))

    if psnr_values and ssim_values:
        print(f"\n---- Test Set Evaluation ----")
        print(f"SSIM: avg={np.mean(ssim_values):.4f}, min={np.min(ssim_values):.4f}, max={np.max(ssim_values):.4f}")
        print(f"PSNR: avg={np.mean(psnr_values):.2f} dB, min={np.min(psnr_values):.2f} dB, max={np.max(psnr_values):.2f} dB")



---- Test Set Evaluation ----
SSIM: avg=0.8918, min=0.7129, max=0.9971
PSNR: avg=35.42 dB, min=28.93 dB, max=51.88 dB


In [3]:
import os
import re
import numpy as np
from PIL import Image
import cv2
from math import log10
from skimage.metrics import structural_similarity as ssim

import torch
import torch.nn as nn
import torch.nn.functional as F

# =========================
# CONFIG
# =========================
FREQ_DIRS = [
    r"C:\Preet\validation dataset_4000\png_images_650M_1083M",
    r"C:\Preet\validation dataset_4000\png_images_700M_1167M",
    r"C:\Preet\validation dataset_4000\png_images_800M_1333M",
    r"C:\Preet\validation dataset_4000\png_images_850M_1416M",
    r"C:\Preet\validation dataset_4000\png_images_900M_1500M",
]

MODEL_PATH = r"C:\Preet\4000_paired_bscans\fk_transunet++.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMAGE_SIZE = (256, 256)

# =========================
# MODEL BLOCKS (same as main code)
# =========================
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True)
        )
    def forward(self, x): return self.conv(x)

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads=4, ff_dim=512, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.drop = nn.Dropout(dropout)
    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + self.drop(attn_out))
        ff_out = self.ff(x)
        x = self.norm2(x + self.drop(ff_out))
        return x

class TransformerBottleneck(nn.Module):
    def __init__(self, in_ch, patch_size=16, num_layers=2, num_heads=4):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = in_ch
        self.flatten = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        self.project = nn.Linear(patch_size*patch_size*in_ch, in_ch)
        self.transformer = nn.Sequential(*[
            TransformerBlock(embed_dim=in_ch, num_heads=num_heads) for _ in range(num_layers)
        ])
        self.reconstruct = nn.Linear(in_ch, patch_size*patch_size*in_ch)
        self.fold = None
    def forward(self, x):
        B,C,H,W = x.shape
        patches = self.flatten(x).transpose(1,2)
        patches = self.project(patches)
        patches = self.transformer(patches)
        patches = self.reconstruct(patches).transpose(1,2)
        if self.fold is None:
            self.fold = nn.Fold(output_size=(H,W), kernel_size=self.patch_size, stride=self.patch_size)
        return self.fold(patches)

class UpCUP(nn.Module):
    """Conv -> PixelShuffle(2x) -> Conv"""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.Sequential(
            nn.Conv2d(in_ch, out_ch*4, kernel_size=3, padding=1),
            nn.PixelShuffle(2),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.GELU()
        )
    def forward(self, x): return self.up(x)

class TransUNet(nn.Module):
    def __init__(self, in_ch=1, out_ch=1, base_ch=64):
        super().__init__()
        # Encoder
        self.enc1 = ConvBlock(in_ch, base_ch)
        self.enc2 = ConvBlock(base_ch, base_ch*2)
        self.enc3 = ConvBlock(base_ch*2, base_ch*4)
        self.pool = nn.MaxPool2d(2)
        # Bottleneck
        self.bottleneck = TransformerBottleneck(base_ch*4, patch_size=16, num_layers=2, num_heads=4)
        # Decoder (CUP + skips)
        self.up3  = UpCUP(base_ch*4, base_ch*2)
        self.dec3 = ConvBlock(base_ch*2 + base_ch*4, base_ch*2)
        self.up2  = UpCUP(base_ch*2, base_ch)
        self.dec2 = ConvBlock(base_ch + base_ch*2, base_ch)
        self.up1  = UpCUP(base_ch, base_ch//2)
        self.dec1 = ConvBlock(base_ch//2 + base_ch, base_ch//2)
        self.out_conv = nn.Conv2d(base_ch//2, out_ch, kernel_size=1)
    def forward(self, x):
        e1 = self.enc1(x)                      # [B,64,H,W]
        e2 = self.enc2(self.pool(e1))          # [B,128,H/2,W/2]
        e3 = self.enc3(self.pool(e2))          # [B,256,H/4,W/4]
        b  = self.bottleneck(self.pool(e3))    # [B,256,H/8,W/8]
        d3 = self.dec3(torch.cat([self.up3(b),  e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
        return self.out_conv(d1)

# =========================
# HELPERS
# =========================
def calculate_psnr(img1, img2):
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0: return float('inf')
    return 20 * log10(255.0 / np.sqrt(mse))

def calculate_ssim(img1, img2):
    return ssim(img1, img2, data_range=255)

def numerical_sort(files):
    return sorted(files, key=lambda f: int(re.search(r'\d+', f).group()))

# =========================
# LOAD MODEL
# =========================
model = TransUNet().to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

# =========================
# INFERENCE LOOP
# =========================
for freq_dir in FREQ_DIRS:
    print(f"\n---- Evaluating {os.path.basename(freq_dir)} ----")
    pred_dir = os.path.join(freq_dir, "predictions_fk_transunet")
    os.makedirs(pred_dir, exist_ok=True)

    low_files = numerical_sort([f for f in os.listdir(freq_dir) if f.endswith("_l.png")])
    high_files = numerical_sort([f for f in os.listdir(freq_dir) if f.endswith("_h.png")])

    psnr_values, ssim_values = [], []

    for i in range(len(low_files)):
        low_path = os.path.join(freq_dir, low_files[i])
        high_path = os.path.join(freq_dir, high_files[i])

        # Load LR
        lr_img = np.array(Image.open(low_path).resize(IMAGE_SIZE), dtype=np.float32) / 255.0
        lr_tensor = torch.tensor(lr_img).unsqueeze(0).unsqueeze(0).to(DEVICE)

        # Predict
        with torch.no_grad():
            pred = model(lr_tensor).cpu().squeeze().numpy()

        # Scale to 0-255
        pred_img = (np.clip(pred, 0.0, 1.0) * 255.0).astype(np.uint8)

        # Save prediction
        pred_path = os.path.join(pred_dir, f"pred_{i+1}.png")
        Image.fromarray(pred_img).save(pred_path)

        # Load GT
        gt_img = cv2.imread(high_path, cv2.IMREAD_GRAYSCALE)
        if pred_img.shape != gt_img.shape:
            pred_img = cv2.resize(pred_img, (gt_img.shape[1], gt_img.shape[0]))

        # Metrics
        psnr_values.append(calculate_psnr(pred_img, gt_img))
        ssim_values.append(calculate_ssim(pred_img, gt_img))

    if psnr_values and ssim_values:
        print(f"SSIM: avg={np.mean(ssim_values):.4f}, min={np.min(ssim_values):.4f}, max={np.max(ssim_values):.4f}")
        print(f"PSNR: avg={np.mean(psnr_values):.2f} dB, min={np.min(psnr_values):.2f} dB, max={np.max(psnr_values):.2f} dB")
    else:
        print("[Error] No valid pairs processed.")



---- Evaluating png_images_650M_1083M ----
SSIM: avg=0.9480, min=0.9125, max=0.9604
PSNR: avg=35.70 dB, min=31.81 dB, max=37.01 dB

---- Evaluating png_images_700M_1167M ----
SSIM: avg=0.9816, min=0.7558, max=0.9873
PSNR: avg=37.80 dB, min=29.25 dB, max=38.87 dB

---- Evaluating png_images_800M_1333M ----
SSIM: avg=0.9896, min=0.9652, max=0.9935
PSNR: avg=38.79 dB, min=33.71 dB, max=39.78 dB

---- Evaluating png_images_850M_1416M ----
SSIM: avg=0.9661, min=0.9283, max=0.9761
PSNR: avg=34.76 dB, min=31.21 dB, max=35.74 dB

---- Evaluating png_images_900M_1500M ----
SSIM: avg=0.9314, min=0.8686, max=0.9496
PSNR: avg=33.14 dB, min=29.57 dB, max=34.01 dB


In [4]:
import os
import re
import numpy as np
from PIL import Image
import cv2
from math import log10
from skimage.metrics import structural_similarity as ssim

import torch
import torch.nn as nn
import torch.nn.functional as F

# =========================
# CONFIG
# =========================
DATASET_DIR = r"C:\Preet\validation dataset_4000\png_images_650M_1083M_1800M"
MODEL_PATH = r"C:\Preet\4000_paired_bscans\fk_transunet++.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMAGE_SIZE = (256, 256)

# =========================
# MODEL BLOCKS (same as main code)
# =========================
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True)
        )
    def forward(self, x): return self.conv(x)

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads=4, ff_dim=512, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.drop = nn.Dropout(dropout)
    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + self.drop(attn_out))
        ff_out = self.ff(x)
        x = self.norm2(x + self.drop(ff_out))
        return x

class TransformerBottleneck(nn.Module):
    def __init__(self, in_ch, patch_size=16, num_layers=2, num_heads=4):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = in_ch
        self.flatten = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        self.project = nn.Linear(patch_size*patch_size*in_ch, in_ch)
        self.transformer = nn.Sequential(*[
            TransformerBlock(embed_dim=in_ch, num_heads=num_heads) for _ in range(num_layers)
        ])
        self.reconstruct = nn.Linear(in_ch, patch_size*patch_size*in_ch)
        self.fold = None
    def forward(self, x):
        B,C,H,W = x.shape
        patches = self.flatten(x).transpose(1,2)
        patches = self.project(patches)
        patches = self.transformer(patches)
        patches = self.reconstruct(patches).transpose(1,2)
        if self.fold is None:
            self.fold = nn.Fold(output_size=(H,W), kernel_size=self.patch_size, stride=self.patch_size)
        return self.fold(patches)

class UpCUP(nn.Module):
    """Conv -> PixelShuffle(2x) -> Conv"""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.Sequential(
            nn.Conv2d(in_ch, out_ch*4, kernel_size=3, padding=1),
            nn.PixelShuffle(2),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.GELU()
        )
    def forward(self, x): return self.up(x)

class TransUNet(nn.Module):
    def __init__(self, in_ch=1, out_ch=1, base_ch=64):
        super().__init__()
        # Encoder
        self.enc1 = ConvBlock(in_ch, base_ch)
        self.enc2 = ConvBlock(base_ch, base_ch*2)
        self.enc3 = ConvBlock(base_ch*2, base_ch*4)
        self.pool = nn.MaxPool2d(2)
        # Bottleneck
        self.bottleneck = TransformerBottleneck(base_ch*4, patch_size=16, num_layers=2, num_heads=4)
        # Decoder (CUP + skips)
        self.up3  = UpCUP(base_ch*4, base_ch*2)
        self.dec3 = ConvBlock(base_ch*2 + base_ch*4, base_ch*2)
        self.up2  = UpCUP(base_ch*2, base_ch)
        self.dec2 = ConvBlock(base_ch + base_ch*2, base_ch)
        self.up1  = UpCUP(base_ch, base_ch//2)
        self.dec1 = ConvBlock(base_ch//2 + base_ch, base_ch//2)
        self.out_conv = nn.Conv2d(base_ch//2, out_ch, kernel_size=1)
    def forward(self, x):
        e1 = self.enc1(x)                      # [B,64,H,W]
        e2 = self.enc2(self.pool(e1))          # [B,128,H/2,W/2]
        e3 = self.enc3(self.pool(e2))          # [B,256,H/4,W/4]
        b  = self.bottleneck(self.pool(e3))    # [B,256,H/8,W/8]
        d3 = self.dec3(torch.cat([self.up3(b),  e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
        return self.out_conv(d1)


# =========================
# HELPERS
# =========================
def calculate_psnr(img1, img2):
    img1, img2 = img1.astype(np.float32), img2.astype(np.float32)
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0: return float('inf')
    return 20 * log10(255.0 / np.sqrt(mse))

def calculate_ssim(img1, img2):
    return ssim(img1, img2, data_range=255)

def numerical_sort(files):
    return sorted(files, key=lambda f: int(re.search(r'\d+', f).group()))

def run_model(img, model):
    """Runs inference with proper scaling and output normalization"""
    tensor = torch.tensor(img, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        pred = model(tensor).cpu().numpy()
    pred = np.squeeze(pred)
    # scale back to [0,255]
    pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)
    pred = (pred * 255.0).astype(np.uint8)
    return pred

# =========================
# LOAD MODEL
# =========================
model = TransUNet().to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

# =========================
# INFERENCE 2-STAGE
# =========================
files_650 = numerical_sort([f for f in os.listdir(DATASET_DIR) if f.endswith("650_bscan.png")])
files_1083 = numerical_sort([f for f in os.listdir(DATASET_DIR) if f.endswith("1083_bscan.png")])
files_1800 = numerical_sort([f for f in os.listdir(DATASET_DIR) if f.endswith("1800_bscan.png")])

psnr_stage1, ssim_stage1 = [], []
psnr_stage2, ssim_stage2 = [], []

pred_dir1 = os.path.join(DATASET_DIR, "predictions_stage1_fk_transunet")
pred_dir2 = os.path.join(DATASET_DIR, "predictions_stage2_fk_transunet")
os.makedirs(pred_dir1, exist_ok=True)
os.makedirs(pred_dir2, exist_ok=True)

for i in range(len(files_650)):
    # ---- Stage 1: 650 -> pred -> compare with 1083 ----
    lr_img = np.array(Image.open(os.path.join(DATASET_DIR, files_650[i])).resize(IMAGE_SIZE), dtype=np.float32) / 255.0
    gt_1083 = cv2.imread(os.path.join(DATASET_DIR, files_1083[i]), cv2.IMREAD_GRAYSCALE)

    pred_1083 = run_model(lr_img, model)
    Image.fromarray(pred_1083).save(os.path.join(pred_dir1, f"pred1_{i+1}.png"))

    if pred_1083.shape != gt_1083.shape:
        pred_1083 = cv2.resize(pred_1083, (gt_1083.shape[1], gt_1083.shape[0]))

    psnr_stage1.append(calculate_psnr(pred_1083, gt_1083))
    ssim_stage1.append(calculate_ssim(pred_1083, gt_1083))

    # ---- Stage 2: pred_1083 -> pred -> compare with 1800 ----
    gt_1800 = cv2.imread(os.path.join(DATASET_DIR, files_1800[i]), cv2.IMREAD_GRAYSCALE)
    pred_1083_resized = cv2.resize(pred_1083, IMAGE_SIZE).astype(np.float32) / 255.0
    pred_1800 = run_model(pred_1083_resized, model)
    Image.fromarray(pred_1800).save(os.path.join(pred_dir2, f"pred2_{i+1}.png"))

    if pred_1800.shape != gt_1800.shape:
        pred_1800 = cv2.resize(pred_1800, (gt_1800.shape[1], gt_1800.shape[0]))

    psnr_stage2.append(calculate_psnr(pred_1800, gt_1800))
    ssim_stage2.append(calculate_ssim(pred_1800, gt_1800))

# =========================
# RESULTS
# =========================
print("\n---- Stage 1: 650 → pred → compare with 1083 ----")
print(f"SSIM: avg={np.mean(ssim_stage1):.4f}, min={np.min(ssim_stage1):.4f}, max={np.max(ssim_stage1):.4f}")
print(f"PSNR: avg={np.mean(psnr_stage1):.2f} dB, min={np.min(psnr_stage1):.2f} dB, max={np.max(psnr_stage1):.2f} dB")

print("\n---- Stage 2: pred_1083 → pred → compare with 1800 ----")
print(f"SSIM: avg={np.mean(ssim_stage2):.4f}, min={np.min(ssim_stage2):.4f}, max={np.max(ssim_stage2):.4f}")
print(f"PSNR: avg={np.mean(psnr_stage2):.2f} dB, min={np.min(psnr_stage2):.2f} dB, max={np.max(psnr_stage2):.2f} dB")



---- Stage 1: 650 → pred → compare with 1083 ----
SSIM: avg=0.9457, min=0.9229, max=0.9597
PSNR: avg=27.56 dB, min=25.87 dB, max=28.01 dB

---- Stage 2: pred_1083 → pred → compare with 1800 ----
SSIM: avg=0.9331, min=0.8888, max=0.9526
PSNR: avg=26.33 dB, min=24.57 dB, max=26.99 dB


In [5]:
import os
import re
import numpy as np
from PIL import Image
import cv2
from math import log10
from skimage.metrics import structural_similarity as ssim

import torch
import torch.nn as nn
import torch.nn.functional as F

# =========================
# CONFIG
# =========================
DATASET_DIR = r"C:\Preet\validation dataset_4000\png_images_900M_1500M_2500M"
MODEL_PATH = r"C:\Preet\4000_paired_bscans\fk_transunet++.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMAGE_SIZE = (256, 256)

# =========================
# MODEL BLOCKS (same as main code)
# =========================
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True)
        )
    def forward(self, x): return self.conv(x)

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads=4, ff_dim=512, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.drop = nn.Dropout(dropout)
    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + self.drop(attn_out))
        ff_out = self.ff(x)
        x = self.norm2(x + self.drop(ff_out))
        return x

class TransformerBottleneck(nn.Module):
    def __init__(self, in_ch, patch_size=16, num_layers=2, num_heads=4):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = in_ch
        self.flatten = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        self.project = nn.Linear(patch_size*patch_size*in_ch, in_ch)
        self.transformer = nn.Sequential(*[
            TransformerBlock(embed_dim=in_ch, num_heads=num_heads) for _ in range(num_layers)
        ])
        self.reconstruct = nn.Linear(in_ch, patch_size*patch_size*in_ch)
        self.fold = None
    def forward(self, x):
        B,C,H,W = x.shape
        patches = self.flatten(x).transpose(1,2)
        patches = self.project(patches)
        patches = self.transformer(patches)
        patches = self.reconstruct(patches).transpose(1,2)
        if self.fold is None:
            self.fold = nn.Fold(output_size=(H,W), kernel_size=self.patch_size, stride=self.patch_size)
        return self.fold(patches)

class UpCUP(nn.Module):
    """Conv -> PixelShuffle(2x) -> Conv"""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.Sequential(
            nn.Conv2d(in_ch, out_ch*4, kernel_size=3, padding=1),
            nn.PixelShuffle(2),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.GELU()
        )
    def forward(self, x): return self.up(x)

class TransUNet(nn.Module):
    def __init__(self, in_ch=1, out_ch=1, base_ch=64):
        super().__init__()
        # Encoder
        self.enc1 = ConvBlock(in_ch, base_ch)
        self.enc2 = ConvBlock(base_ch, base_ch*2)
        self.enc3 = ConvBlock(base_ch*2, base_ch*4)
        self.pool = nn.MaxPool2d(2)
        # Bottleneck
        self.bottleneck = TransformerBottleneck(base_ch*4, patch_size=16, num_layers=2, num_heads=4)
        # Decoder (CUP + skips)
        self.up3  = UpCUP(base_ch*4, base_ch*2)
        self.dec3 = ConvBlock(base_ch*2 + base_ch*4, base_ch*2)
        self.up2  = UpCUP(base_ch*2, base_ch)
        self.dec2 = ConvBlock(base_ch + base_ch*2, base_ch)
        self.up1  = UpCUP(base_ch, base_ch//2)
        self.dec1 = ConvBlock(base_ch//2 + base_ch, base_ch//2)
        self.out_conv = nn.Conv2d(base_ch//2, out_ch, kernel_size=1)
    def forward(self, x):
        e1 = self.enc1(x)                      # [B,64,H,W]
        e2 = self.enc2(self.pool(e1))          # [B,128,H/2,W/2]
        e3 = self.enc3(self.pool(e2))          # [B,256,H/4,W/4]
        b  = self.bottleneck(self.pool(e3))    # [B,256,H/8,W/8]
        d3 = self.dec3(torch.cat([self.up3(b),  e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
        return self.out_conv(d1)


# =========================
# HELPERS
# =========================
def calculate_psnr(img1, img2):
    img1, img2 = img1.astype(np.float32), img2.astype(np.float32)
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0: return float('inf')
    return 20 * log10(255.0 / np.sqrt(mse))

def calculate_ssim(img1, img2):
    return ssim(img1, img2, data_range=255)

def numerical_sort(files):
    return sorted(files, key=lambda f: int(re.search(r'\d+', f).group()))

def run_model(img, model):
    """Runs inference with proper scaling and output normalization"""
    tensor = torch.tensor(img, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        pred = model(tensor).cpu().numpy()
    pred = np.squeeze(pred)
    # scale back to [0,255]
    pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)
    pred = (pred * 255.0).astype(np.uint8)
    return pred

# =========================
# LOAD MODEL
# =========================
model = TransUNet().to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

# =========================
# INFERENCE 2-STAGE
# =========================
files_650 = numerical_sort([f for f in os.listdir(DATASET_DIR) if f.endswith("900_bscan.png")])
files_1083 = numerical_sort([f for f in os.listdir(DATASET_DIR) if f.endswith("1500_bscan.png")])
files_1800 = numerical_sort([f for f in os.listdir(DATASET_DIR) if f.endswith("2500_bscan.png")])

psnr_stage1, ssim_stage1 = [], []
psnr_stage2, ssim_stage2 = [], []

pred_dir1 = os.path.join(DATASET_DIR, "predictions_stage1_fk_transunet")
pred_dir2 = os.path.join(DATASET_DIR, "predictions_stage2_fk_transunet")
os.makedirs(pred_dir1, exist_ok=True)
os.makedirs(pred_dir2, exist_ok=True)

for i in range(len(files_650)):
    # ---- Stage 1: 650 -> pred -> compare with 1083 ----
    lr_img = np.array(Image.open(os.path.join(DATASET_DIR, files_650[i])).resize(IMAGE_SIZE), dtype=np.float32) / 255.0
    gt_1083 = cv2.imread(os.path.join(DATASET_DIR, files_1083[i]), cv2.IMREAD_GRAYSCALE)

    pred_1083 = run_model(lr_img, model)
    Image.fromarray(pred_1083).save(os.path.join(pred_dir1, f"pred1_{i+1}.png"))

    if pred_1083.shape != gt_1083.shape:
        pred_1083 = cv2.resize(pred_1083, (gt_1083.shape[1], gt_1083.shape[0]))

    psnr_stage1.append(calculate_psnr(pred_1083, gt_1083))
    ssim_stage1.append(calculate_ssim(pred_1083, gt_1083))

    # ---- Stage 2: pred_1083 -> pred -> compare with 1800 ----
    gt_1800 = cv2.imread(os.path.join(DATASET_DIR, files_1800[i]), cv2.IMREAD_GRAYSCALE)
    pred_1083_resized = cv2.resize(pred_1083, IMAGE_SIZE).astype(np.float32) / 255.0
    pred_1800 = run_model(pred_1083_resized, model)
    Image.fromarray(pred_1800).save(os.path.join(pred_dir2, f"pred2_{i+1}.png"))

    if pred_1800.shape != gt_1800.shape:
        pred_1800 = cv2.resize(pred_1800, (gt_1800.shape[1], gt_1800.shape[0]))

    psnr_stage2.append(calculate_psnr(pred_1800, gt_1800))
    ssim_stage2.append(calculate_ssim(pred_1800, gt_1800))

# =========================
# RESULTS
# =========================
print("\n---- Stage 1: 900 → pred → compare with 1500 ----")
print(f"SSIM: avg={np.mean(ssim_stage1):.4f}, min={np.min(ssim_stage1):.4f}, max={np.max(ssim_stage1):.4f}")
print(f"PSNR: avg={np.mean(psnr_stage1):.2f} dB, min={np.min(psnr_stage1):.2f} dB, max={np.max(psnr_stage1):.2f} dB")

print("\n---- Stage 2: pred_1500 → pred → compare with 2500 ----")
print(f"SSIM: avg={np.mean(ssim_stage2):.4f}, min={np.min(ssim_stage2):.4f}, max={np.max(ssim_stage2):.4f}")
print(f"PSNR: avg={np.mean(psnr_stage2):.2f} dB, min={np.min(psnr_stage2):.2f} dB, max={np.max(psnr_stage2):.2f} dB")


---- Stage 1: 900 → pred → compare with 1500 ----
SSIM: avg=0.7392, min=0.5940, max=0.8697
PSNR: avg=17.32 dB, min=15.82 dB, max=18.08 dB

---- Stage 2: pred_1500 → pred → compare with 2500 ----
SSIM: avg=0.7404, min=0.6277, max=0.8415
PSNR: avg=15.55 dB, min=14.07 dB, max=16.20 dB


------------

---------------

----------------------

---------------

clean dataset

In [1]:
import os
import math
import cv2
import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

# ==========================
# CONFIG
# ==========================
DATASET_DIR = r"C:\Preet\clean_paired_bscans"  # *_l.png and *_h.png paired images
IMAGE_SIZE = (256, 256)
BATCH_SIZE = 8
EPOCHS = 75
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Running on:", DEVICE)

MODEL_PATH = os.path.join(DATASET_DIR, "fk_transunet++.pth")
RESULTS_DIR = os.path.join(DATASET_DIR, "predictions_fk_transunet")
os.makedirs(RESULTS_DIR, exist_ok=True)

# f–k loss weights / options (tune if needed)
FK_WEIGHT = 0.5           # contribution of f–k term to total loss
FK_USE_LOG = True         # use log magnitude for stability
FK_USE_HANN = True        # apply 2D Hann window before FFT to reduce leakage
FK_KY = (0.05, 0.85)      # rectangular band in normalized ky (temporal freq)
FK_KX = (0.00, 0.90)      # rectangular band in normalized kx (spatial wavenumber)

# ==========================
# DATASET
# ==========================
class GPRDataset(Dataset):
    def __init__(self, x_paths, y_paths):
        self.x_paths = x_paths
        self.y_paths = y_paths

    def __len__(self):
        return len(self.x_paths)

    def __getitem__(self, idx):
        x = np.array(Image.open(self.x_paths[idx]).resize(IMAGE_SIZE), dtype=np.float32) / 255.0
        y = np.array(Image.open(self.y_paths[idx]).resize(IMAGE_SIZE), dtype=np.float32) / 255.0
        x = torch.tensor(x).unsqueeze(0)  # (1,H,W)
        y = torch.tensor(y).unsqueeze(0)
        return x, y

def load_data(dataset_dir):
    low_paths, high_paths = [], []
    for file in os.listdir(dataset_dir):
        if file.endswith("_l.png"):
            low_path = os.path.join(dataset_dir, file)
            high_path = os.path.join(dataset_dir, file.replace("_l.png", "_h.png"))
            if os.path.exists(high_path):
                low_paths.append(low_path)
                high_paths.append(high_path)
    return low_paths, high_paths

all_x, all_y = load_data(DATASET_DIR)

# Split 70/15/15
train_x, temp_x, train_y, temp_y = train_test_split(all_x, all_y, test_size=0.30, random_state=42)
val_x, test_x, val_y, test_y = train_test_split(temp_x, temp_y, test_size=0.50, random_state=42)

print(f"Train: {len(train_x)}, Val: {len(val_x)}, Test: {len(test_x)}")

train_loader = DataLoader(GPRDataset(train_x, train_y), batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(GPRDataset(val_x, val_y), batch_size=BATCH_SIZE, shuffle=False)
test_loader  = DataLoader(GPRDataset(test_x, test_y), batch_size=1, shuffle=False)

# ==========================
# MODEL BLOCKS
# ==========================
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.conv(x)

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads=4, ff_dim=512, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + self.drop(attn_out))
        ff_out = self.ff(x)
        x = self.norm2(x + self.drop(ff_out))
        return x

class TransformerBottleneck(nn.Module):
    def __init__(self, in_ch, patch_size=16, num_layers=2, num_heads=4):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = in_ch

        self.flatten = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        self.project = nn.Linear(patch_size * patch_size * in_ch, in_ch)

        self.transformer = nn.Sequential(*[
            TransformerBlock(embed_dim=in_ch, num_heads=num_heads) for _ in range(num_layers)
        ])

        self.reconstruct = nn.Linear(in_ch, patch_size * patch_size * in_ch)
        self.fold = None  # initialized at runtime

    def forward(self, x):
        B, C, H, W = x.shape
        patches = self.flatten(x).transpose(1, 2)  # (B, N, patch_dim)
        patches = self.project(patches)
        patches = self.transformer(patches)
        patches = self.reconstruct(patches).transpose(1, 2)
        if self.fold is None:
            self.fold = nn.Fold(output_size=(H, W),
                                kernel_size=self.patch_size,
                                stride=self.patch_size)
        x_reconstructed = self.fold(patches)
        return x_reconstructed

# ---- NEW: CUP upsampler block ----
class UpCUP(nn.Module):
    """Detail-preserving upsampler: conv -> pixelshuffle(2x) -> conv"""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.Sequential(
            nn.Conv2d(in_ch, out_ch * 4, kernel_size=3, padding=1),
            nn.PixelShuffle(2),                         # (H,W) -> (2H,2W)
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.GELU()
        )
    def forward(self, x):
        return self.up(x)

class TransUNet(nn.Module):
    def __init__(self, in_ch=1, out_ch=1, base_ch=64):
        super().__init__()
        # ================= Encoder =================
        self.enc1 = ConvBlock(in_ch, base_ch)            # 1 -> 64
        self.enc2 = ConvBlock(base_ch, base_ch*2)        # 64 -> 128
        self.enc3 = ConvBlock(base_ch*2, base_ch*4)      # 128 -> 256
        self.pool = nn.MaxPool2d(2)

        # ================= Bottleneck =================
        self.bottleneck = TransformerBottleneck(base_ch*4, patch_size=16, num_layers=2, num_heads=4)

        # ================= Decoder (CUP) =================
        # stage with e3 skip
        self.up3  = UpCUP(base_ch*4, base_ch*2)                         # 256 -> 128 @ H/4
        self.dec3 = ConvBlock(base_ch*2 + base_ch*4, base_ch*2)         # [128 + 256] -> 128
        # stage with e2 skip
        self.up2  = UpCUP(base_ch*2, base_ch)                           # 128 -> 64  @ H/2
        self.dec2 = ConvBlock(base_ch + base_ch*2, base_ch)             # [64 + 128] -> 64
        # NEW: stage with e1 (high-res) skip
        self.up1  = UpCUP(base_ch, base_ch//2)                          # 64 -> 32  @ H
        self.dec1 = ConvBlock(base_ch//2 + base_ch, base_ch//2)         # [32 + 64] -> 32

        self.out_conv = nn.Conv2d(base_ch//2, out_ch, kernel_size=1)    # 32 -> 1

    def forward(self, x):
        # ---------------- Encoder ----------------
        e1 = self.enc1(x)                       # [B,64,H,W]
        e2 = self.enc2(self.pool(e1))           # [B,128,H/2,W/2]
        e3 = self.enc3(self.pool(e2))           # [B,256,H/4,W/4]

        # ---------------- Bottleneck ----------------
        b = self.bottleneck(self.pool(e3))      # [B,256,H/8,W/8]

        # ---------------- Decoder (CUP + all skips) ----------------
        d3 = self.dec3(torch.cat([self.up3(b),  e3], dim=1))   # -> [B,128,H/4,W/4]
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))   # -> [B,64, H/2,W/2]
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))   # -> [B,32, H,  W]
        out = self.out_conv(d1)                                # -> [B,1,  H,  W]
        return out

# ==========================
# f–k LOSS (frequency–wavenumber consistency)
# ==========================
def make_rect_fk_mask(H, W, device, ky=(0.05, 0.85), kx=(0.00, 0.90)):
    """
    Simple rectangular bandpass in normalized frequency space.
    ky, kx are fractions of Nyquist in [0,1].
    """
    yy = torch.linspace(-1, 1, H, device=device).abs().view(H, 1).expand(H, W)
    xx = torch.linspace(-1, 1, W, device=device).abs().view(1, W).expand(H, W)
    m = ((yy >= ky[0]) & (yy <= ky[1]) & (xx >= kx[0]) & (xx <= kx[1])).float()
    return m.unsqueeze(0).unsqueeze(0)  # (1,1,H,W)

def fk_loss(pred, target, bandpass=None, p=1, use_log=True, window=True):
    """
    pred/target: (B,1,H,W) in [0,1]
    Compare spectra in f–k domain (2D FFT). L1 by default.
    """
    assert pred.shape == target.shape
    if window:
        H, W = pred.shape[-2:]
        wy = torch.hann_window(H, device=pred.device).view(1,1,H,1)
        wx = torch.hann_window(W, device=pred.device).view(1,1,1,W)
        win = wy * wx
        pred = pred * win
        target = target * win

    P = torch.fft.fft2(pred, norm='ortho')
    T = torch.fft.fft2(target, norm='ortho')
    Pm = torch.abs(P)
    Tm = torch.abs(T)
    if use_log:
        Pm = torch.log1p(Pm)
        Tm = torch.log1p(Tm)
    diff = (Pm - Tm).abs() if p == 1 else (Pm - Tm).pow(2)
    if bandpass is not None:
        diff = diff * bandpass.to(diff.dtype)
    return diff.mean()

# ==========================
# TRAINING
# ==========================
model = TransUNet().to(DEVICE)
huber = nn.HuberLoss(delta=1.0)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

best_val_loss = float("inf")

for epoch in range(EPOCHS):
    model.train()
    train_loss = 0.0
    for x, y in train_loader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        optimizer.zero_grad()

        preds = model(x)

        # Combine Huber + f–k loss (compute f–k on clamped [0,1] for stability)
        fk_mask = make_rect_fk_mask(y.shape[2], y.shape[3], y.device, ky=FK_KY, kx=FK_KX)
        loss_huber = huber(preds, y)
        loss_fk = fk_loss(preds.clamp(0,1), y, bandpass=fk_mask, p=1, use_log=FK_USE_LOG, window=FK_USE_HANN)
        loss = loss_huber + FK_WEIGHT * loss_fk

        loss.backward()
        optimizer.step()
        train_loss += loss.item() * x.size(0)
    train_loss /= len(train_loader.dataset)

    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            preds = model(x)
            fk_mask = make_rect_fk_mask(y.shape[2], y.shape[3], y.device, ky=FK_KY, kx=FK_KX)
            loss_huber = huber(preds, y)
            loss_fk = fk_loss(preds.clamp(0,1), y, bandpass=fk_mask, p=1, use_log=FK_USE_LOG, window=FK_USE_HANN)
            val_loss += (loss_huber + FK_WEIGHT * loss_fk).item() * x.size(0)
    val_loss /= len(val_loader.dataset)

    print(f"Epoch [{epoch+1}/{EPOCHS}] Train {train_loss:.6f} | Val {val_loss:.6f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), MODEL_PATH)
        print(f"  ✅ Saved Best Model at Epoch {epoch+1}")

# ==========================
# INFERENCE
# ==========================
print("\nRunning inference on test set...")
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

os.makedirs(RESULTS_DIR, exist_ok=True)

for i, (x, y) in enumerate(test_loader):
    x = x.to(DEVICE)
    with torch.no_grad():
        pred = model(x).cpu().squeeze(0).squeeze(0).numpy()
    pred_img = (np.clip(pred, 0.0, 1.0) * 255.0).astype(np.uint8)
    Image.fromarray(pred_img).save(os.path.join(RESULTS_DIR, f"pred_{i+1}.png"))

print(f"Predictions saved in {RESULTS_DIR}")



Running on: cuda
Train: 5810, Val: 1245, Test: 1245
Epoch [1/75] Train 0.004010 | Val 0.000176
  ✅ Saved Best Model at Epoch 1
Epoch [2/75] Train 0.000105 | Val 0.000086
  ✅ Saved Best Model at Epoch 2
Epoch [3/75] Train 0.000084 | Val 0.000077
  ✅ Saved Best Model at Epoch 3
Epoch [4/75] Train 0.000077 | Val 0.000072
  ✅ Saved Best Model at Epoch 4
Epoch [5/75] Train 0.000075 | Val 0.000069
  ✅ Saved Best Model at Epoch 5
Epoch [6/75] Train 0.000072 | Val 0.000067
  ✅ Saved Best Model at Epoch 6
Epoch [7/75] Train 0.000070 | Val 0.000064
  ✅ Saved Best Model at Epoch 7
Epoch [8/75] Train 0.000066 | Val 0.000068
Epoch [9/75] Train 0.000062 | Val 0.000064
  ✅ Saved Best Model at Epoch 9
Epoch [10/75] Train 0.000055 | Val 0.000053
  ✅ Saved Best Model at Epoch 10
Epoch [11/75] Train 0.000052 | Val 0.000049
  ✅ Saved Best Model at Epoch 11
Epoch [12/75] Train 0.000051 | Val 0.000049
Epoch [13/75] Train 0.000050 | Val 0.000058
Epoch [14/75] Train 0.000050 | Val 0.000059
Epoch [15/75] Train

In [2]:
import os
import cv2
import numpy as np
from skimage.metrics import structural_similarity as ssim
from math import log10
import shutil

# ========================== PATHS ==========================
RESULTS_DIR = r"C:\Preet\clean_paired_bscans\predictions_fk_transunet"  # predictions from TransUNet
GT_TEST_DIR = r"C:\Preet\clean_paired_bscans\ground_truth_test_fk_transunet"

# Copy ground truth test images to a dedicated folder
os.makedirs(GT_TEST_DIR, exist_ok=True)
for f in test_y:  # test_y comes from your train/val/test split
    shutil.copy(f, GT_TEST_DIR)

# ========================== FUNCTIONS ==========================
def calculate_psnr(img1, img2):
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * log10(255.0 / np.sqrt(mse))

def calculate_ssim(img1, img2):
    return ssim(img1, img2, data_range=255)

# ========================== MAIN EVALUATION ==========================
psnr_values, ssim_values = [], []

pred_files = sorted([f for f in os.listdir(RESULTS_DIR) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
gt_files   = sorted([f for f in os.listdir(GT_TEST_DIR) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])

num_pairs = min(len(pred_files), len(gt_files))
if num_pairs == 0:
    print("[Error] No matching image files found in both directories.")
else:
    if len(pred_files) != len(gt_files):
        print(f"[Warning] Different number of images: Predictions={len(pred_files)}, Ground Truth={len(gt_files)}")
        print(f"Evaluating only first {num_pairs} matched pairs.")

    for i in range(num_pairs):
        pred_path = os.path.join(RESULTS_DIR, pred_files[i])
        gt_path   = os.path.join(GT_TEST_DIR, gt_files[i])

        pred_img = cv2.imread(pred_path, cv2.IMREAD_GRAYSCALE)
        gt_img   = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)

        if pred_img is None or gt_img is None:
            print(f"[Error] Could not load: {pred_files[i]} or {gt_files[i]}")
            continue

        if pred_img.shape != gt_img.shape:
            pred_img = cv2.resize(pred_img, (gt_img.shape[1], gt_img.shape[0]))

        psnr_values.append(calculate_psnr(pred_img, gt_img))
        ssim_values.append(calculate_ssim(pred_img, gt_img))

    if psnr_values and ssim_values:
        print(f"\n---- Test Set Evaluation ----")
        print(f"SSIM: avg={np.mean(ssim_values):.4f}, min={np.min(ssim_values):.4f}, max={np.max(ssim_values):.4f}")
        print(f"PSNR: avg={np.mean(psnr_values):.2f} dB, min={np.min(psnr_values):.2f} dB, max={np.max(psnr_values):.2f} dB")



---- Test Set Evaluation ----
SSIM: avg=0.9052, min=0.7496, max=0.9984
PSNR: avg=35.91 dB, min=29.22 dB, max=49.25 dB


In [3]:
import os
import re
import numpy as np
from PIL import Image
import cv2
from math import log10
from skimage.metrics import structural_similarity as ssim

import torch
import torch.nn as nn
import torch.nn.functional as F

# =========================
# CONFIG
# =========================
FREQ_DIRS = [
    r"C:\Preet\clean_validation dataset\png_images_650M_1083M",
    r"C:\Preet\clean_validation dataset\png_images_700M_1167M",
    r"C:\Preet\clean_validation dataset\png_images_800M_1333M",
    r"C:\Preet\clean_validation dataset\png_images_850M_1416M",
    r"C:\Preet\clean_validation dataset\png_images_900M_1500M",
]

MODEL_PATH = r"C:\Preet\clean_paired_bscans\fk_transunet++.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMAGE_SIZE = (256, 256)

# =========================
# MODEL BLOCKS (same as main code)
# =========================
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True)
        )
    def forward(self, x): return self.conv(x)

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads=4, ff_dim=512, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.drop = nn.Dropout(dropout)
    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + self.drop(attn_out))
        ff_out = self.ff(x)
        x = self.norm2(x + self.drop(ff_out))
        return x

class TransformerBottleneck(nn.Module):
    def __init__(self, in_ch, patch_size=16, num_layers=2, num_heads=4):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = in_ch
        self.flatten = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        self.project = nn.Linear(patch_size*patch_size*in_ch, in_ch)
        self.transformer = nn.Sequential(*[
            TransformerBlock(embed_dim=in_ch, num_heads=num_heads) for _ in range(num_layers)
        ])
        self.reconstruct = nn.Linear(in_ch, patch_size*patch_size*in_ch)
        self.fold = None
    def forward(self, x):
        B,C,H,W = x.shape
        patches = self.flatten(x).transpose(1,2)
        patches = self.project(patches)
        patches = self.transformer(patches)
        patches = self.reconstruct(patches).transpose(1,2)
        if self.fold is None:
            self.fold = nn.Fold(output_size=(H,W), kernel_size=self.patch_size, stride=self.patch_size)
        return self.fold(patches)

class UpCUP(nn.Module):
    """Conv -> PixelShuffle(2x) -> Conv"""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.Sequential(
            nn.Conv2d(in_ch, out_ch*4, kernel_size=3, padding=1),
            nn.PixelShuffle(2),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.GELU()
        )
    def forward(self, x): return self.up(x)

class TransUNet(nn.Module):
    def __init__(self, in_ch=1, out_ch=1, base_ch=64):
        super().__init__()
        # Encoder
        self.enc1 = ConvBlock(in_ch, base_ch)
        self.enc2 = ConvBlock(base_ch, base_ch*2)
        self.enc3 = ConvBlock(base_ch*2, base_ch*4)
        self.pool = nn.MaxPool2d(2)
        # Bottleneck
        self.bottleneck = TransformerBottleneck(base_ch*4, patch_size=16, num_layers=2, num_heads=4)
        # Decoder (CUP + skips)
        self.up3  = UpCUP(base_ch*4, base_ch*2)
        self.dec3 = ConvBlock(base_ch*2 + base_ch*4, base_ch*2)
        self.up2  = UpCUP(base_ch*2, base_ch)
        self.dec2 = ConvBlock(base_ch + base_ch*2, base_ch)
        self.up1  = UpCUP(base_ch, base_ch//2)
        self.dec1 = ConvBlock(base_ch//2 + base_ch, base_ch//2)
        self.out_conv = nn.Conv2d(base_ch//2, out_ch, kernel_size=1)
    def forward(self, x):
        e1 = self.enc1(x)                      # [B,64,H,W]
        e2 = self.enc2(self.pool(e1))          # [B,128,H/2,W/2]
        e3 = self.enc3(self.pool(e2))          # [B,256,H/4,W/4]
        b  = self.bottleneck(self.pool(e3))    # [B,256,H/8,W/8]
        d3 = self.dec3(torch.cat([self.up3(b),  e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
        return self.out_conv(d1)

# =========================
# HELPERS
# =========================
def calculate_psnr(img1, img2):
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0: return float('inf')
    return 20 * log10(255.0 / np.sqrt(mse))

def calculate_ssim(img1, img2):
    return ssim(img1, img2, data_range=255)

def numerical_sort(files):
    return sorted(files, key=lambda f: int(re.search(r'\d+', f).group()))

# =========================
# LOAD MODEL
# =========================
model = TransUNet().to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

# =========================
# INFERENCE LOOP
# =========================
for freq_dir in FREQ_DIRS:
    print(f"\n---- Evaluating {os.path.basename(freq_dir)} ----")
    pred_dir = os.path.join(freq_dir, "predictions_fk_transunet")
    os.makedirs(pred_dir, exist_ok=True)

    low_files = numerical_sort([f for f in os.listdir(freq_dir) if f.endswith("_l.png")])
    high_files = numerical_sort([f for f in os.listdir(freq_dir) if f.endswith("_h.png")])

    psnr_values, ssim_values = [], []

    for i in range(len(low_files)):
        low_path = os.path.join(freq_dir, low_files[i])
        high_path = os.path.join(freq_dir, high_files[i])

        # Load LR
        lr_img = np.array(Image.open(low_path).resize(IMAGE_SIZE), dtype=np.float32) / 255.0
        lr_tensor = torch.tensor(lr_img).unsqueeze(0).unsqueeze(0).to(DEVICE)

        # Predict
        with torch.no_grad():
            pred = model(lr_tensor).cpu().squeeze().numpy()

        # Scale to 0-255
        pred_img = (np.clip(pred, 0.0, 1.0) * 255.0).astype(np.uint8)

        # Save prediction
        pred_path = os.path.join(pred_dir, f"pred_{i+1}.png")
        Image.fromarray(pred_img).save(pred_path)

        # Load GT
        gt_img = cv2.imread(high_path, cv2.IMREAD_GRAYSCALE)
        if pred_img.shape != gt_img.shape:
            pred_img = cv2.resize(pred_img, (gt_img.shape[1], gt_img.shape[0]))

        # Metrics
        psnr_values.append(calculate_psnr(pred_img, gt_img))
        ssim_values.append(calculate_ssim(pred_img, gt_img))

    if psnr_values and ssim_values:
        print(f"SSIM: avg={np.mean(ssim_values):.4f}, min={np.min(ssim_values):.4f}, max={np.max(ssim_values):.4f}")
        print(f"PSNR: avg={np.mean(psnr_values):.2f} dB, min={np.min(psnr_values):.2f} dB, max={np.max(psnr_values):.2f} dB")
    else:
        print("[Error] No valid pairs processed.")



---- Evaluating png_images_650M_1083M ----
SSIM: avg=0.9488, min=0.9197, max=0.9592
PSNR: avg=36.03 dB, min=31.83 dB, max=37.06 dB

---- Evaluating png_images_700M_1167M ----
SSIM: avg=0.9861, min=0.8317, max=0.9897
PSNR: avg=39.36 dB, min=31.47 dB, max=40.59 dB

---- Evaluating png_images_800M_1333M ----
SSIM: avg=0.9925, min=0.9829, max=0.9955
PSNR: avg=41.44 dB, min=35.36 dB, max=42.73 dB

---- Evaluating png_images_850M_1416M ----
SSIM: avg=0.9698, min=0.9399, max=0.9783
PSNR: avg=36.40 dB, min=31.82 dB, max=37.46 dB

---- Evaluating png_images_900M_1500M ----
SSIM: avg=0.9388, min=0.8828, max=0.9545
PSNR: avg=35.19 dB, min=30.32 dB, max=36.40 dB


In [4]:
import os
import re
import numpy as np
from PIL import Image
import cv2
from math import log10
from skimage.metrics import structural_similarity as ssim

import torch
import torch.nn as nn
import torch.nn.functional as F

# =========================
# CONFIG
# =========================
DATASET_DIR = r"C:\Preet\clean_validation dataset\png_images_650M_1083M_1800M"
MODEL_PATH = r"C:\Preet\clean_paired_bscans\fk_transunet++.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMAGE_SIZE = (256, 256)

# =========================
# MODEL BLOCKS (same as main code)
# =========================
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True)
        )
    def forward(self, x): return self.conv(x)

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads=4, ff_dim=512, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.drop = nn.Dropout(dropout)
    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + self.drop(attn_out))
        ff_out = self.ff(x)
        x = self.norm2(x + self.drop(ff_out))
        return x

class TransformerBottleneck(nn.Module):
    def __init__(self, in_ch, patch_size=16, num_layers=2, num_heads=4):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = in_ch
        self.flatten = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        self.project = nn.Linear(patch_size*patch_size*in_ch, in_ch)
        self.transformer = nn.Sequential(*[
            TransformerBlock(embed_dim=in_ch, num_heads=num_heads) for _ in range(num_layers)
        ])
        self.reconstruct = nn.Linear(in_ch, patch_size*patch_size*in_ch)
        self.fold = None
    def forward(self, x):
        B,C,H,W = x.shape
        patches = self.flatten(x).transpose(1,2)
        patches = self.project(patches)
        patches = self.transformer(patches)
        patches = self.reconstruct(patches).transpose(1,2)
        if self.fold is None:
            self.fold = nn.Fold(output_size=(H,W), kernel_size=self.patch_size, stride=self.patch_size)
        return self.fold(patches)

class UpCUP(nn.Module):
    """Conv -> PixelShuffle(2x) -> Conv"""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.Sequential(
            nn.Conv2d(in_ch, out_ch*4, kernel_size=3, padding=1),
            nn.PixelShuffle(2),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.GELU()
        )
    def forward(self, x): return self.up(x)

class TransUNet(nn.Module):
    def __init__(self, in_ch=1, out_ch=1, base_ch=64):
        super().__init__()
        # Encoder
        self.enc1 = ConvBlock(in_ch, base_ch)
        self.enc2 = ConvBlock(base_ch, base_ch*2)
        self.enc3 = ConvBlock(base_ch*2, base_ch*4)
        self.pool = nn.MaxPool2d(2)
        # Bottleneck
        self.bottleneck = TransformerBottleneck(base_ch*4, patch_size=16, num_layers=2, num_heads=4)
        # Decoder (CUP + skips)
        self.up3  = UpCUP(base_ch*4, base_ch*2)
        self.dec3 = ConvBlock(base_ch*2 + base_ch*4, base_ch*2)
        self.up2  = UpCUP(base_ch*2, base_ch)
        self.dec2 = ConvBlock(base_ch + base_ch*2, base_ch)
        self.up1  = UpCUP(base_ch, base_ch//2)
        self.dec1 = ConvBlock(base_ch//2 + base_ch, base_ch//2)
        self.out_conv = nn.Conv2d(base_ch//2, out_ch, kernel_size=1)
    def forward(self, x):
        e1 = self.enc1(x)                      # [B,64,H,W]
        e2 = self.enc2(self.pool(e1))          # [B,128,H/2,W/2]
        e3 = self.enc3(self.pool(e2))          # [B,256,H/4,W/4]
        b  = self.bottleneck(self.pool(e3))    # [B,256,H/8,W/8]
        d3 = self.dec3(torch.cat([self.up3(b),  e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
        return self.out_conv(d1)


# =========================
# HELPERS
# =========================
def calculate_psnr(img1, img2):
    img1, img2 = img1.astype(np.float32), img2.astype(np.float32)
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0: return float('inf')
    return 20 * log10(255.0 / np.sqrt(mse))

def calculate_ssim(img1, img2):
    return ssim(img1, img2, data_range=255)

def numerical_sort(files):
    return sorted(files, key=lambda f: int(re.search(r'\d+', f).group()))

def run_model(img, model):
    """Runs inference with proper scaling and output normalization"""
    tensor = torch.tensor(img, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        pred = model(tensor).cpu().numpy()
    pred = np.squeeze(pred)
    # scale back to [0,255]
    pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)
    pred = (pred * 255.0).astype(np.uint8)
    return pred

# =========================
# LOAD MODEL
# =========================
model = TransUNet().to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

# =========================
# INFERENCE 2-STAGE
# =========================
files_650 = numerical_sort([f for f in os.listdir(DATASET_DIR) if f.endswith("650_bscan.png")])
files_1083 = numerical_sort([f for f in os.listdir(DATASET_DIR) if f.endswith("1083_bscan.png")])
files_1800 = numerical_sort([f for f in os.listdir(DATASET_DIR) if f.endswith("1800_bscan.png")])

psnr_stage1, ssim_stage1 = [], []
psnr_stage2, ssim_stage2 = [], []

pred_dir1 = os.path.join(DATASET_DIR, "predictions_stage1_fk_transunet")
pred_dir2 = os.path.join(DATASET_DIR, "predictions_stage2_fk_transunet")
os.makedirs(pred_dir1, exist_ok=True)
os.makedirs(pred_dir2, exist_ok=True)

for i in range(len(files_650)):
    # ---- Stage 1: 650 -> pred -> compare with 1083 ----
    lr_img = np.array(Image.open(os.path.join(DATASET_DIR, files_650[i])).resize(IMAGE_SIZE), dtype=np.float32) / 255.0
    gt_1083 = cv2.imread(os.path.join(DATASET_DIR, files_1083[i]), cv2.IMREAD_GRAYSCALE)

    pred_1083 = run_model(lr_img, model)
    Image.fromarray(pred_1083).save(os.path.join(pred_dir1, f"pred1_{i+1}.png"))

    if pred_1083.shape != gt_1083.shape:
        pred_1083 = cv2.resize(pred_1083, (gt_1083.shape[1], gt_1083.shape[0]))

    psnr_stage1.append(calculate_psnr(pred_1083, gt_1083))
    ssim_stage1.append(calculate_ssim(pred_1083, gt_1083))

    # ---- Stage 2: pred_1083 -> pred -> compare with 1800 ----
    gt_1800 = cv2.imread(os.path.join(DATASET_DIR, files_1800[i]), cv2.IMREAD_GRAYSCALE)
    pred_1083_resized = cv2.resize(pred_1083, IMAGE_SIZE).astype(np.float32) / 255.0
    pred_1800 = run_model(pred_1083_resized, model)
    Image.fromarray(pred_1800).save(os.path.join(pred_dir2, f"pred2_{i+1}.png"))

    if pred_1800.shape != gt_1800.shape:
        pred_1800 = cv2.resize(pred_1800, (gt_1800.shape[1], gt_1800.shape[0]))

    psnr_stage2.append(calculate_psnr(pred_1800, gt_1800))
    ssim_stage2.append(calculate_ssim(pred_1800, gt_1800))

# =========================
# RESULTS
# =========================
print("\n---- Stage 1: 650 → pred → compare with 1083 ----")
print(f"SSIM: avg={np.mean(ssim_stage1):.4f}, min={np.min(ssim_stage1):.4f}, max={np.max(ssim_stage1):.4f}")
print(f"PSNR: avg={np.mean(psnr_stage1):.2f} dB, min={np.min(psnr_stage1):.2f} dB, max={np.max(psnr_stage1):.2f} dB")

print("\n---- Stage 2: pred_1083 → pred → compare with 1800 ----")
print(f"SSIM: avg={np.mean(ssim_stage2):.4f}, min={np.min(ssim_stage2):.4f}, max={np.max(ssim_stage2):.4f}")
print(f"PSNR: avg={np.mean(psnr_stage2):.2f} dB, min={np.min(psnr_stage2):.2f} dB, max={np.max(psnr_stage2):.2f} dB")



---- Stage 1: 650 → pred → compare with 1083 ----
SSIM: avg=0.9466, min=0.9251, max=0.9574
PSNR: avg=27.01 dB, min=26.07 dB, max=27.40 dB

---- Stage 2: pred_1083 → pred → compare with 1800 ----
SSIM: avg=0.9508, min=0.9226, max=0.9677
PSNR: avg=28.89 dB, min=26.36 dB, max=29.64 dB


In [5]:
import os
import re
import numpy as np
from PIL import Image
import cv2
from math import log10
from skimage.metrics import structural_similarity as ssim

import torch
import torch.nn as nn
import torch.nn.functional as F

# =========================
# CONFIG
# =========================
DATASET_DIR = r"C:\Preet\clean_validation dataset\png_images_900M_1500M_2500M"
MODEL_PATH = r"C:\Preet\clean_paired_bscans\fk_transunet++.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMAGE_SIZE = (256, 256)

# =========================
# MODEL BLOCKS (same as main code)
# =========================
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True)
        )
    def forward(self, x): return self.conv(x)

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads=4, ff_dim=512, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.drop = nn.Dropout(dropout)
    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + self.drop(attn_out))
        ff_out = self.ff(x)
        x = self.norm2(x + self.drop(ff_out))
        return x

class TransformerBottleneck(nn.Module):
    def __init__(self, in_ch, patch_size=16, num_layers=2, num_heads=4):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = in_ch
        self.flatten = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        self.project = nn.Linear(patch_size*patch_size*in_ch, in_ch)
        self.transformer = nn.Sequential(*[
            TransformerBlock(embed_dim=in_ch, num_heads=num_heads) for _ in range(num_layers)
        ])
        self.reconstruct = nn.Linear(in_ch, patch_size*patch_size*in_ch)
        self.fold = None
    def forward(self, x):
        B,C,H,W = x.shape
        patches = self.flatten(x).transpose(1,2)
        patches = self.project(patches)
        patches = self.transformer(patches)
        patches = self.reconstruct(patches).transpose(1,2)
        if self.fold is None:
            self.fold = nn.Fold(output_size=(H,W), kernel_size=self.patch_size, stride=self.patch_size)
        return self.fold(patches)

class UpCUP(nn.Module):
    """Conv -> PixelShuffle(2x) -> Conv"""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.Sequential(
            nn.Conv2d(in_ch, out_ch*4, kernel_size=3, padding=1),
            nn.PixelShuffle(2),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.GELU()
        )
    def forward(self, x): return self.up(x)

class TransUNet(nn.Module):
    def __init__(self, in_ch=1, out_ch=1, base_ch=64):
        super().__init__()
        # Encoder
        self.enc1 = ConvBlock(in_ch, base_ch)
        self.enc2 = ConvBlock(base_ch, base_ch*2)
        self.enc3 = ConvBlock(base_ch*2, base_ch*4)
        self.pool = nn.MaxPool2d(2)
        # Bottleneck
        self.bottleneck = TransformerBottleneck(base_ch*4, patch_size=16, num_layers=2, num_heads=4)
        # Decoder (CUP + skips)
        self.up3  = UpCUP(base_ch*4, base_ch*2)
        self.dec3 = ConvBlock(base_ch*2 + base_ch*4, base_ch*2)
        self.up2  = UpCUP(base_ch*2, base_ch)
        self.dec2 = ConvBlock(base_ch + base_ch*2, base_ch)
        self.up1  = UpCUP(base_ch, base_ch//2)
        self.dec1 = ConvBlock(base_ch//2 + base_ch, base_ch//2)
        self.out_conv = nn.Conv2d(base_ch//2, out_ch, kernel_size=1)
    def forward(self, x):
        e1 = self.enc1(x)                      # [B,64,H,W]
        e2 = self.enc2(self.pool(e1))          # [B,128,H/2,W/2]
        e3 = self.enc3(self.pool(e2))          # [B,256,H/4,W/4]
        b  = self.bottleneck(self.pool(e3))    # [B,256,H/8,W/8]
        d3 = self.dec3(torch.cat([self.up3(b),  e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
        return self.out_conv(d1)


# =========================
# HELPERS
# =========================
def calculate_psnr(img1, img2):
    img1, img2 = img1.astype(np.float32), img2.astype(np.float32)
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0: return float('inf')
    return 20 * log10(255.0 / np.sqrt(mse))

def calculate_ssim(img1, img2):
    return ssim(img1, img2, data_range=255)

def numerical_sort(files):
    return sorted(files, key=lambda f: int(re.search(r'\d+', f).group()))

def run_model(img, model):
    """Runs inference with proper scaling and output normalization"""
    tensor = torch.tensor(img, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        pred = model(tensor).cpu().numpy()
    pred = np.squeeze(pred)
    # scale back to [0,255]
    pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)
    pred = (pred * 255.0).astype(np.uint8)
    return pred

# =========================
# LOAD MODEL
# =========================
model = TransUNet().to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

# =========================
# INFERENCE 2-STAGE
# =========================
files_650 = numerical_sort([f for f in os.listdir(DATASET_DIR) if f.endswith("900_bscan.png")])
files_1083 = numerical_sort([f for f in os.listdir(DATASET_DIR) if f.endswith("1500_bscan.png")])
files_1800 = numerical_sort([f for f in os.listdir(DATASET_DIR) if f.endswith("2500_bscan.png")])

psnr_stage1, ssim_stage1 = [], []
psnr_stage2, ssim_stage2 = [], []

pred_dir1 = os.path.join(DATASET_DIR, "predictions_stage1_fk_transunet")
pred_dir2 = os.path.join(DATASET_DIR, "predictions_stage2_fk_transunet")
os.makedirs(pred_dir1, exist_ok=True)
os.makedirs(pred_dir2, exist_ok=True)

for i in range(len(files_650)):
    # ---- Stage 1: 650 -> pred -> compare with 1083 ----
    lr_img = np.array(Image.open(os.path.join(DATASET_DIR, files_650[i])).resize(IMAGE_SIZE), dtype=np.float32) / 255.0
    gt_1083 = cv2.imread(os.path.join(DATASET_DIR, files_1083[i]), cv2.IMREAD_GRAYSCALE)

    pred_1083 = run_model(lr_img, model)
    Image.fromarray(pred_1083).save(os.path.join(pred_dir1, f"pred1_{i+1}.png"))

    if pred_1083.shape != gt_1083.shape:
        pred_1083 = cv2.resize(pred_1083, (gt_1083.shape[1], gt_1083.shape[0]))

    psnr_stage1.append(calculate_psnr(pred_1083, gt_1083))
    ssim_stage1.append(calculate_ssim(pred_1083, gt_1083))

    # ---- Stage 2: pred_1083 -> pred -> compare with 1800 ----
    gt_1800 = cv2.imread(os.path.join(DATASET_DIR, files_1800[i]), cv2.IMREAD_GRAYSCALE)
    pred_1083_resized = cv2.resize(pred_1083, IMAGE_SIZE).astype(np.float32) / 255.0
    pred_1800 = run_model(pred_1083_resized, model)
    Image.fromarray(pred_1800).save(os.path.join(pred_dir2, f"pred2_{i+1}.png"))

    if pred_1800.shape != gt_1800.shape:
        pred_1800 = cv2.resize(pred_1800, (gt_1800.shape[1], gt_1800.shape[0]))

    psnr_stage2.append(calculate_psnr(pred_1800, gt_1800))
    ssim_stage2.append(calculate_ssim(pred_1800, gt_1800))

# =========================
# RESULTS
# =========================
print("\n---- Stage 1: 900 → pred → compare with 1500 ----")
print(f"SSIM: avg={np.mean(ssim_stage1):.4f}, min={np.min(ssim_stage1):.4f}, max={np.max(ssim_stage1):.4f}")
print(f"PSNR: avg={np.mean(psnr_stage1):.2f} dB, min={np.min(psnr_stage1):.2f} dB, max={np.max(psnr_stage1):.2f} dB")

print("\n---- Stage 2: pred_1500 → pred → compare with 2500 ----")
print(f"SSIM: avg={np.mean(ssim_stage2):.4f}, min={np.min(ssim_stage2):.4f}, max={np.max(ssim_stage2):.4f}")
print(f"PSNR: avg={np.mean(psnr_stage2):.2f} dB, min={np.min(psnr_stage2):.2f} dB, max={np.max(psnr_stage2):.2f} dB")


---- Stage 1: 900 → pred → compare with 1500 ----
SSIM: avg=0.7426, min=0.5871, max=0.8675
PSNR: avg=17.44 dB, min=15.73 dB, max=18.12 dB

---- Stage 2: pred_1500 → pred → compare with 2500 ----
SSIM: avg=0.7630, min=0.6372, max=0.8696
PSNR: avg=16.75 dB, min=15.39 dB, max=17.40 dB


-----------------

-------------------

In [1]:
import os
import cv2
import torch
import numpy as np
from torch import nn

# ==========================
# CONFIG (paths edited)
# ==========================
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Running on:", DEVICE)

MODEL_PATH   = r"C:\Preet\clean_paired_bscans\fk_transunet++.pth"
IMAGE_DIR    = r"C:\Preet\Real GPR data"  # where the two PNGs live
INPUT_IMAGES = [
    os.path.join(IMAGE_DIR, "cropped_patch_256x256.png"),
]
OUTPUT_DIR   = r"C:\Preet\Real GPR data\FK-transunet\single_image_preds"
TARGET_SIZE  = (256, 256)  # H, W used by the model

os.makedirs(OUTPUT_DIR, exist_ok=True)

# =========================
# MODEL BLOCKS (same as main code)
# =========================
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True)
        )
    def forward(self, x): return self.conv(x)

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads=4, ff_dim=512, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.drop = nn.Dropout(dropout)
    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + self.drop(attn_out))
        ff_out = self.ff(x)
        x = self.norm2(x + self.drop(ff_out))
        return x

class TransformerBottleneck(nn.Module):
    def __init__(self, in_ch, patch_size=16, num_layers=2, num_heads=4):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = in_ch
        self.flatten = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        self.project = nn.Linear(patch_size*patch_size*in_ch, in_ch)
        self.transformer = nn.Sequential(*[
            TransformerBlock(embed_dim=in_ch, num_heads=num_heads) for _ in range(num_layers)
        ])
        self.reconstruct = nn.Linear(in_ch, patch_size*patch_size*in_ch)
        self.fold = None
    def forward(self, x):
        B,C,H,W = x.shape
        patches = self.flatten(x).transpose(1,2)
        patches = self.project(patches)
        patches = self.transformer(patches)
        patches = self.reconstruct(patches).transpose(1,2)
        if self.fold is None:
            self.fold = nn.Fold(output_size=(H,W), kernel_size=self.patch_size, stride=self.patch_size)
        return self.fold(patches)

class UpCUP(nn.Module):
    """Conv -> PixelShuffle(2x) -> Conv"""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.Sequential(
            nn.Conv2d(in_ch, out_ch*4, kernel_size=3, padding=1),
            nn.PixelShuffle(2),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.GELU()
        )
    def forward(self, x): return self.up(x)

class TransUNet(nn.Module):
    def __init__(self, in_ch=1, out_ch=1, base_ch=64):
        super().__init__()
        # Encoder
        self.enc1 = ConvBlock(in_ch, base_ch)
        self.enc2 = ConvBlock(base_ch, base_ch*2)
        self.enc3 = ConvBlock(base_ch*2, base_ch*4)
        self.pool = nn.MaxPool2d(2)
        # Bottleneck
        self.bottleneck = TransformerBottleneck(base_ch*4, patch_size=16, num_layers=2, num_heads=4)
        # Decoder (CUP + skips)
        self.up3  = UpCUP(base_ch*4, base_ch*2)
        self.dec3 = ConvBlock(base_ch*2 + base_ch*4, base_ch*2)
        self.up2  = UpCUP(base_ch*2, base_ch)
        self.dec2 = ConvBlock(base_ch + base_ch*2, base_ch)
        self.up1  = UpCUP(base_ch, base_ch//2)
        self.dec1 = ConvBlock(base_ch//2 + base_ch, base_ch//2)
        self.out_conv = nn.Conv2d(base_ch//2, out_ch, kernel_size=1)
    def forward(self, x):
        e1 = self.enc1(x)                      # [B,64,H,W]
        e2 = self.enc2(self.pool(e1))          # [B,128,H/2,W/2]
        e3 = self.enc3(self.pool(e2))          # [B,256,H/4,W/4]
        b  = self.bottleneck(self.pool(e3))    # [B,256,H/8,W/8]
        d3 = self.dec3(torch.cat([self.up3(b),  e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
        return self.out_conv(d1)

# ==========================
# HELPERS
# ==========================
def preprocess(img_path, target_hw=(256, 256)):
    """Load grayscale, ensure 256x256, normalize to [0,1], -> (1,1,H,W) tensor."""
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise FileNotFoundError(f"Cannot read image: {img_path}")
    if img.shape != target_hw:
        img = cv2.resize(img, (target_hw[1], target_hw[0]), interpolation=cv2.INTER_CUBIC)
    img = img.astype(np.float32)
    img = (img - img.min()) / (img.max() - img.min() + 1e-8)
    tensor = torch.from_numpy(img).unsqueeze(0).unsqueeze(0)  # (1,1,H,W)
    return tensor

def save_pred(pred_tensor, save_path):
    pred = pred_tensor.squeeze().cpu().numpy()
    pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)
    pred = (pred * 255).astype(np.uint8)
    cv2.imwrite(save_path, pred)

# ==========================
# MAIN
# ==========================
if __name__ == "__main__":
    # Load model
    model = TransUNet().to(DEVICE)
    model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
    model.eval()
    print("Model loaded:", MODEL_PATH)

    for img_path in INPUT_IMAGES:
        name = os.path.splitext(os.path.basename(img_path))[0]
        out_path = os.path.join(OUTPUT_DIR, f"{name}_pred_256x256.png")

        x = preprocess(img_path, TARGET_SIZE).to(DEVICE)
        with torch.no_grad():
            y = model(x)

        save_pred(y, out_path)
        print(f"✅ Saved prediction for {name}: {out_path}")


Running on: cuda
Model loaded: C:\Preet\clean_paired_bscans\fk_transunet++.pth
✅ Saved prediction for cropped_patch_256x256: C:\Preet\Real GPR data\FK-transunet\single_image_preds\cropped_patch_256x256_pred_256x256.png


test for 400-670Mhz on the trained model of 750-1250Mhz

In [2]:
import os
import re
import numpy as np
from PIL import Image
import cv2
from math import log10
from skimage.metrics import structural_similarity as ssim

import torch
import torch.nn as nn
import torch.nn.functional as F

# =========================
# CONFIG
# =========================
TEST_DIR = r"C:\Preet\400_670_Dataset"
MODEL_PATH = r"C:\Preet\clean_paired_bscans\fk_transunet++.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMAGE_SIZE = (256, 256)

# =========================
# MODEL DEFINITION
# =========================
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True)
        )
    def forward(self, x): return self.conv(x)

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads=4, ff_dim=512, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.drop = nn.Dropout(dropout)
    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + self.drop(attn_out))
        ff_out = self.ff(x)
        x = self.norm2(x + self.drop(ff_out))
        return x

class TransformerBottleneck(nn.Module):
    def __init__(self, in_ch, patch_size=16, num_layers=2, num_heads=4):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = in_ch
        self.flatten = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        self.project = nn.Linear(patch_size * patch_size * in_ch, in_ch)
        self.transformer = nn.Sequential(*[
            TransformerBlock(embed_dim=in_ch, num_heads=num_heads) for _ in range(num_layers)
        ])
        self.reconstruct = nn.Linear(in_ch, patch_size * patch_size * in_ch)
        self.fold = None

    def forward(self, x):
        B, C, H, W = x.shape
        patches = self.flatten(x).transpose(1, 2)
        patches = self.project(patches)
        patches = self.transformer(patches)
        patches = self.reconstruct(patches).transpose(1, 2)
        if self.fold is None:
            self.fold = nn.Fold(output_size=(H, W), kernel_size=self.patch_size, stride=self.patch_size)
        return self.fold(patches)

class UpCUP(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.Sequential(
            nn.Conv2d(in_ch, out_ch * 4, kernel_size=3, padding=1),
            nn.PixelShuffle(2),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.GELU()
        )
    def forward(self, x): return self.up(x)

class TransUNet(nn.Module):
    def __init__(self, in_ch=1, out_ch=1, base_ch=64):
        super().__init__()
        self.enc1 = ConvBlock(in_ch, base_ch)
        self.enc2 = ConvBlock(base_ch, base_ch * 2)
        self.enc3 = ConvBlock(base_ch * 2, base_ch * 4)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = TransformerBottleneck(base_ch * 4)
        self.up3 = UpCUP(base_ch * 4, base_ch * 2)
        self.dec3 = ConvBlock(base_ch * 2 + base_ch * 4, base_ch * 2)
        self.up2 = UpCUP(base_ch * 2, base_ch)
        self.dec2 = ConvBlock(base_ch + base_ch * 2, base_ch)
        self.up1 = UpCUP(base_ch, base_ch // 2)
        self.dec1 = ConvBlock(base_ch // 2 + base_ch, base_ch // 2)
        self.out_conv = nn.Conv2d(base_ch // 2, out_ch, kernel_size=1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        b = self.bottleneck(self.pool(e3))
        d3 = self.dec3(torch.cat([self.up3(b), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
        return self.out_conv(d1)

# =========================
# HELPERS
# =========================
def calculate_psnr(img1, img2):
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * log10(255.0 / np.sqrt(mse))

def calculate_ssim(img1, img2):
    return ssim(img1, img2, data_range=255)

def numerical_sort(files):
    return sorted(files, key=lambda f: int(re.search(r'\d+', f).group()))

# =========================
# LOAD MODEL
# =========================
model = TransUNet().to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

# =========================
# INFERENCE ON TEST FOLDER
# =========================
print(f"\n---- Evaluating {TEST_DIR} ----")
pred_dir = os.path.join(TEST_DIR, "predictions_fk_transunet")
os.makedirs(pred_dir, exist_ok=True)

low_files = numerical_sort([f for f in os.listdir(TEST_DIR) if f.endswith("_l.png")])
high_files = numerical_sort([f for f in os.listdir(TEST_DIR) if f.endswith("_h.png")])

psnr_values, ssim_values = [], []

for i in range(len(low_files)):
    low_path = os.path.join(TEST_DIR, low_files[i])
    high_path = os.path.join(TEST_DIR, high_files[i])

    # Load LR
    lr_img = np.array(Image.open(low_path).resize(IMAGE_SIZE), dtype=np.float32) / 255.0
    lr_tensor = torch.tensor(lr_img).unsqueeze(0).unsqueeze(0).to(DEVICE)

    # Predict
    with torch.no_grad():
        pred = model(lr_tensor).cpu().squeeze().numpy()

    # Scale to 0–255
    pred_img = (np.clip(pred, 0.0, 1.0) * 255.0).astype(np.uint8)
    pred_path = os.path.join(pred_dir, f"pred_{i+1}.png")
    Image.fromarray(pred_img).save(pred_path)

    # Load GT
    gt_img = cv2.imread(high_path, cv2.IMREAD_GRAYSCALE)
    if pred_img.shape != gt_img.shape:
        pred_img = cv2.resize(pred_img, (gt_img.shape[1], gt_img.shape[0]))

    # Metrics
    psnr_values.append(calculate_psnr(pred_img, gt_img))
    ssim_values.append(calculate_ssim(pred_img, gt_img))

if psnr_values and ssim_values:
    print(f"SSIM: avg={np.mean(ssim_values):.4f}, min={np.min(ssim_values):.4f}, max={np.max(ssim_values):.4f}")
    print(f"PSNR: avg={np.mean(psnr_values):.2f} dB, min={np.min(psnr_values):.2f} dB, max={np.max(psnr_values):.2f} dB")
else:
    print("[Error] No valid *_l.png and *_h.png pairs found.")



---- Evaluating C:\Preet\400_670_Dataset ----
SSIM: avg=0.7356, min=0.6056, max=0.8002
PSNR: avg=27.81 dB, min=27.30 dB, max=29.10 dB


------------------------

-----------------

-----------------

------------------

training on 400-670 Mhz dataset size -> 3404 paired images -> approx 3400 paired images -> 1700 images LR 

In [2]:
import os
import math
import cv2
import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

# ==========================
# CONFIG
# ==========================
DATASET_DIR = r"C:\Preet\png_images_400M_670M"  # *_l.png and *_h.png paired images
IMAGE_SIZE = (256, 256)
BATCH_SIZE = 8
EPOCHS = 75
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Running on:", DEVICE)

MODEL_PATH = os.path.join(DATASET_DIR, "fk_transunet++.pth")
RESULTS_DIR = os.path.join(DATASET_DIR, "predictions_fk_transunet")
os.makedirs(RESULTS_DIR, exist_ok=True)

# f–k loss weights / options (tune if needed)
FK_WEIGHT = 0.5           # contribution of f–k term to total loss
FK_USE_LOG = True         # use log magnitude for stability
FK_USE_HANN = True        # apply 2D Hann window before FFT to reduce leakage
FK_KY = (0.05, 0.85)      # rectangular band in normalized ky (temporal freq)
FK_KX = (0.00, 0.90)      # rectangular band in normalized kx (spatial wavenumber)

# ==========================
# DATASET
# ==========================
class GPRDataset(Dataset):
    def __init__(self, x_paths, y_paths):
        self.x_paths = x_paths
        self.y_paths = y_paths

    def __len__(self):
        return len(self.x_paths)

    def __getitem__(self, idx):
        x = np.array(Image.open(self.x_paths[idx]).resize(IMAGE_SIZE), dtype=np.float32) / 255.0
        y = np.array(Image.open(self.y_paths[idx]).resize(IMAGE_SIZE), dtype=np.float32) / 255.0
        x = torch.tensor(x).unsqueeze(0)  # (1,H,W)
        y = torch.tensor(y).unsqueeze(0)
        return x, y

def load_data(dataset_dir):
    low_paths, high_paths = [], []
    for file in os.listdir(dataset_dir):
        if file.endswith("_l.png"):
            low_path = os.path.join(dataset_dir, file)
            high_path = os.path.join(dataset_dir, file.replace("_l.png", "_h.png"))
            if os.path.exists(high_path):
                low_paths.append(low_path)
                high_paths.append(high_path)
    return low_paths, high_paths

# ==========================
# DATA SPLIT (80:20 Train/Test)
# ==========================
all_x, all_y = load_data(DATASET_DIR)

train_x, test_x, train_y, test_y = train_test_split(
    all_x, all_y, test_size=0.20, random_state=42
)

print(f"Train: {len(train_x)}, Test: {len(test_x)}")

train_loader = DataLoader(GPRDataset(train_x, train_y), batch_size=BATCH_SIZE, shuffle=True)
test_loader  = DataLoader(GPRDataset(test_x, test_y), batch_size=1, shuffle=False)

# ==========================
# MODEL BLOCKS
# ==========================
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.conv(x)

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads=4, ff_dim=512, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + self.drop(attn_out))
        ff_out = self.ff(x)
        x = self.norm2(x + self.drop(ff_out))
        return x

class TransformerBottleneck(nn.Module):
    def __init__(self, in_ch, patch_size=16, num_layers=2, num_heads=4):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = in_ch

        self.flatten = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        self.project = nn.Linear(patch_size * patch_size * in_ch, in_ch)

        self.transformer = nn.Sequential(*[
            TransformerBlock(embed_dim=in_ch, num_heads=num_heads) for _ in range(num_layers)
        ])

        self.reconstruct = nn.Linear(in_ch, patch_size * patch_size * in_ch)
        self.fold = None  # initialized at runtime

    def forward(self, x):
        B, C, H, W = x.shape
        patches = self.flatten(x).transpose(1, 2)  # (B, N, patch_dim)
        patches = self.project(patches)
        patches = self.transformer(patches)
        patches = self.reconstruct(patches).transpose(1, 2)
        if self.fold is None:
            self.fold = nn.Fold(output_size=(H, W),
                                kernel_size=self.patch_size,
                                stride=self.patch_size)
        x_reconstructed = self.fold(patches)
        return x_reconstructed

# ---- NEW: CUP upsampler block ----
class UpCUP(nn.Module):
    """Detail-preserving upsampler: conv -> pixelshuffle(2x) -> conv"""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.Sequential(
            nn.Conv2d(in_ch, out_ch * 4, kernel_size=3, padding=1),
            nn.PixelShuffle(2),                         # (H,W) -> (2H,2W)
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.GELU()
        )
    def forward(self, x):
        return self.up(x)

class TransUNet(nn.Module):
    def __init__(self, in_ch=1, out_ch=1, base_ch=64):
        super().__init__()
        # ================= Encoder =================
        self.enc1 = ConvBlock(in_ch, base_ch)            # 1 -> 64
        self.enc2 = ConvBlock(base_ch, base_ch*2)        # 64 -> 128
        self.enc3 = ConvBlock(base_ch*2, base_ch*4)      # 128 -> 256
        self.pool = nn.MaxPool2d(2)

        # ================= Bottleneck =================
        self.bottleneck = TransformerBottleneck(base_ch*4, patch_size=16, num_layers=2, num_heads=4)

        # ================= Decoder (CUP) =================
        # stage with e3 skip
        self.up3  = UpCUP(base_ch*4, base_ch*2)                         # 256 -> 128 @ H/4
        self.dec3 = ConvBlock(base_ch*2 + base_ch*4, base_ch*2)         # [128 + 256] -> 128
        # stage with e2 skip
        self.up2  = UpCUP(base_ch*2, base_ch)                           # 128 -> 64  @ H/2
        self.dec2 = ConvBlock(base_ch + base_ch*2, base_ch)             # [64 + 128] -> 64
        # NEW: stage with e1 (high-res) skip
        self.up1  = UpCUP(base_ch, base_ch//2)                          # 64 -> 32  @ H
        self.dec1 = ConvBlock(base_ch//2 + base_ch, base_ch//2)         # [32 + 64] -> 32

        self.out_conv = nn.Conv2d(base_ch//2, out_ch, kernel_size=1)    # 32 -> 1

    def forward(self, x):
        # ---------------- Encoder ----------------
        e1 = self.enc1(x)                       # [B,64,H,W]
        e2 = self.enc2(self.pool(e1))           # [B,128,H/2,W/2]
        e3 = self.enc3(self.pool(e2))           # [B,256,H/4,W/4]

        # ---------------- Bottleneck ----------------
        b = self.bottleneck(self.pool(e3))      # [B,256,H/8,W/8]

        # ---------------- Decoder (CUP + all skips) ----------------
        d3 = self.dec3(torch.cat([self.up3(b),  e3], dim=1))   # -> [B,128,H/4,W/4]
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))   # -> [B,64, H/2,W/2]
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))   # -> [B,32, H,  W]
        out = self.out_conv(d1)                                # -> [B,1,  H,  W]
        return out

# ==========================
# f–k LOSS (frequency–wavenumber consistency)
# ==========================
def make_rect_fk_mask(H, W, device, ky=(0.05, 0.85), kx=(0.00, 0.90)):
    """
    Simple rectangular bandpass in normalized frequency space.
    ky, kx are fractions of Nyquist in [0,1].
    """
    yy = torch.linspace(-1, 1, H, device=device).abs().view(H, 1).expand(H, W)
    xx = torch.linspace(-1, 1, W, device=device).abs().view(1, W).expand(H, W)
    m = ((yy >= ky[0]) & (yy <= ky[1]) & (xx >= kx[0]) & (xx <= kx[1])).float()
    return m.unsqueeze(0).unsqueeze(0)  # (1,1,H,W)

def fk_loss(pred, target, bandpass=None, p=1, use_log=True, window=True):
    """
    pred/target: (B,1,H,W) in [0,1]
    Compare spectra in f–k domain (2D FFT). L1 by default.
    """
    assert pred.shape == target.shape
    if window:
        H, W = pred.shape[-2:]
        wy = torch.hann_window(H, device=pred.device).view(1,1,H,1)
        wx = torch.hann_window(W, device=pred.device).view(1,1,1,W)
        win = wy * wx
        pred = pred * win
        target = target * win

    P = torch.fft.fft2(pred, norm='ortho')
    T = torch.fft.fft2(target, norm='ortho')
    Pm = torch.abs(P)
    Tm = torch.abs(T)
    if use_log:
        Pm = torch.log1p(Pm)
        Tm = torch.log1p(Tm)
    diff = (Pm - Tm).abs() if p == 1 else (Pm - Tm).pow(2)
    if bandpass is not None:
        diff = diff * bandpass.to(diff.dtype)
    return diff.mean()

# ==========================
# TRAINING
# ==========================
model = TransUNet().to(DEVICE)
huber = nn.HuberLoss(delta=1.0)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

best_train_loss = float("inf")

for epoch in range(EPOCHS):
    model.train()
    train_loss = 0.0
    for x, y in train_loader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        optimizer.zero_grad()

        preds = model(x)

        # Combine Huber + f–k loss
        fk_mask = make_rect_fk_mask(y.shape[2], y.shape[3], y.device, ky=FK_KY, kx=FK_KX)
        loss_huber = huber(preds, y)
        loss_fk = fk_loss(preds.clamp(0,1), y, bandpass=fk_mask,
                          p=1, use_log=FK_USE_LOG, window=FK_USE_HANN)
        loss = loss_huber + FK_WEIGHT * loss_fk

        loss.backward()
        optimizer.step()
        train_loss += loss.item() * x.size(0)

    train_loss /= len(train_loader.dataset)

    print(f"Epoch [{epoch+1}/{EPOCHS}] Train {train_loss:.6f}")

    # Save best model based on training loss
    if train_loss < best_train_loss:
        best_train_loss = train_loss
        torch.save(model.state_dict(), MODEL_PATH)
        print(f"  ✅ Saved Best Model at Epoch {epoch+1}")


# ==========================
# INFERENCE (Save with matching test image numbers)
# ==========================
print("\nRunning inference on test set...")
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

os.makedirs(RESULTS_DIR, exist_ok=True)

for x_path, y_path in zip(test_x, test_y):
    # Extract numeric ID from filename (e.g., '310' from '310bscan_l.png')
    base_name = os.path.basename(x_path)
    num_label = ''.join([c for c in base_name if c.isdigit()])
    if not num_label:
        num_label = os.path.splitext(base_name)[0]  # fallback to full name

    x = Image.open(x_path).convert('L')
    x = np.array(x.resize(IMAGE_SIZE), dtype=np.float32) / 255.0
    x = torch.tensor(x).unsqueeze(0).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        pred = model(x).cpu().squeeze().numpy()

    pred_img = (np.clip(pred, 0.0, 1.0) * 255.0).astype(np.uint8)
    save_path = os.path.join(RESULTS_DIR, f"pred_{num_label}.png")
    Image.fromarray(pred_img).save(save_path)

print(f"✅ Predictions saved with matching labels in {RESULTS_DIR}")




Running on: cuda
Train: 1361, Test: 341
Epoch [1/75] Train 0.022333
  ✅ Saved Best Model at Epoch 1
Epoch [2/75] Train 0.000778
  ✅ Saved Best Model at Epoch 2
Epoch [3/75] Train 0.000553
  ✅ Saved Best Model at Epoch 3
Epoch [4/75] Train 0.000510
  ✅ Saved Best Model at Epoch 4
Epoch [5/75] Train 0.000378
  ✅ Saved Best Model at Epoch 5
Epoch [6/75] Train 0.000239
  ✅ Saved Best Model at Epoch 6
Epoch [7/75] Train 0.000207
  ✅ Saved Best Model at Epoch 7
Epoch [8/75] Train 0.000165
  ✅ Saved Best Model at Epoch 8
Epoch [9/75] Train 0.000154
  ✅ Saved Best Model at Epoch 9
Epoch [10/75] Train 0.000149
  ✅ Saved Best Model at Epoch 10
Epoch [11/75] Train 0.000135
  ✅ Saved Best Model at Epoch 11
Epoch [12/75] Train 0.000125
  ✅ Saved Best Model at Epoch 12
Epoch [13/75] Train 0.000119
  ✅ Saved Best Model at Epoch 13
Epoch [14/75] Train 0.000124
Epoch [15/75] Train 0.000114
  ✅ Saved Best Model at Epoch 15
Epoch [16/75] Train 0.000174
Epoch [17/75] Train 0.000112
  ✅ Saved Best Model at

In [3]:
import os
import cv2
import numpy as np
from skimage.metrics import structural_similarity as ssim
from math import log10
import shutil

# ========================== PATHS ==========================
RESULTS_DIR = r"C:\Preet\png_images_400M_670M\predictions_fk_transunet"  # predictions from TransUNet
GT_TEST_DIR = r"C:\Preet\png_images_400M_670M\ground_truth_test_fk_transunet"

# Copy ground truth test images to a dedicated folder
os.makedirs(GT_TEST_DIR, exist_ok=True)
for f in test_y:  # test_y comes from your train/val/test split
    shutil.copy(f, GT_TEST_DIR)

# ========================== FUNCTIONS ==========================
def calculate_psnr(img1, img2):
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * log10(255.0 / np.sqrt(mse))

def calculate_ssim(img1, img2):
    return ssim(img1, img2, data_range=255)

# ========================== MAIN EVALUATION ==========================
psnr_values, ssim_values = [], []

pred_files = sorted([f for f in os.listdir(RESULTS_DIR) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
gt_files   = sorted([f for f in os.listdir(GT_TEST_DIR) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])

num_pairs = min(len(pred_files), len(gt_files))
if num_pairs == 0:
    print("[Error] No matching image files found in both directories.")
else:
    if len(pred_files) != len(gt_files):
        print(f"[Warning] Different number of images: Predictions={len(pred_files)}, Ground Truth={len(gt_files)}")
        print(f"Evaluating only first {num_pairs} matched pairs.")

    for i in range(num_pairs):
        pred_path = os.path.join(RESULTS_DIR, pred_files[i])
        gt_path   = os.path.join(GT_TEST_DIR, gt_files[i])

        pred_img = cv2.imread(pred_path, cv2.IMREAD_GRAYSCALE)
        gt_img   = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)

        if pred_img is None or gt_img is None:
            print(f"[Error] Could not load: {pred_files[i]} or {gt_files[i]}")
            continue

        if pred_img.shape != gt_img.shape:
            pred_img = cv2.resize(pred_img, (gt_img.shape[1], gt_img.shape[0]))

        psnr_values.append(calculate_psnr(pred_img, gt_img))
        ssim_values.append(calculate_ssim(pred_img, gt_img))

    if psnr_values and ssim_values:
        print(f"\n---- Test Set Evaluation ----")
        print(f"SSIM: avg={np.mean(ssim_values):.4f}, min={np.min(ssim_values):.4f}, max={np.max(ssim_values):.4f}")
        print(f"PSNR: avg={np.mean(psnr_values):.2f} dB, min={np.min(psnr_values):.2f} dB, max={np.max(psnr_values):.2f} dB")


---- Test Set Evaluation ----
SSIM: avg=0.9583, min=0.8245, max=0.9983
PSNR: avg=42.84 dB, min=29.97 dB, max=54.73 dB


In [4]:
# test real gpr data from mendaley's
import os
import numpy as np
from PIL import Image
import torch
import torch.nn as nn

# ==========================
# CONFIG
# ==========================
TEST_DIR   = r"C:\Preet\Real GPR data\mendaley\Utilities_resized"  # folder with real GPR images
IMAGE_SIZE = (256, 256)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

MODEL_PATH = r"C:\Preet\png_images_400M_670M\fk_transunet++.pth"   # trained weights
RESULTS_DIR = r"C:\Preet\Real GPR data\mendaley\Utilities_predictions_fk"
os.makedirs(RESULTS_DIR, exist_ok=True)


# ==========================
# MODEL BLOCKS (copied from training script)
# ==========================
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.conv(x)

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads=4, ff_dim=512, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.drop = nn.Dropout(dropout)
    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + self.drop(attn_out))
        ff_out = self.ff(x)
        x = self.norm2(x + self.drop(ff_out))
        return x

class TransformerBottleneck(nn.Module):
    def __init__(self, in_ch, patch_size=16, num_layers=2, num_heads=4):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = in_ch
        self.flatten = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        self.project = nn.Linear(patch_size * patch_size * in_ch, in_ch)
        self.transformer = nn.Sequential(*[
            TransformerBlock(embed_dim=in_ch, num_heads=num_heads) for _ in range(num_layers)
        ])
        self.reconstruct = nn.Linear(in_ch, patch_size * patch_size * in_ch)
        self.fold = None
    def forward(self, x):
        B, C, H, W = x.shape
        patches = self.flatten(x).transpose(1, 2)
        patches = self.project(patches)
        patches = self.transformer(patches)
        patches = self.reconstruct(patches).transpose(1, 2)
        if self.fold is None:
            self.fold = nn.Fold(output_size=(H, W),
                                kernel_size=self.patch_size,
                                stride=self.patch_size)
        x_reconstructed = self.fold(patches)
        return x_reconstructed

class UpCUP(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.Sequential(
            nn.Conv2d(in_ch, out_ch * 4, kernel_size=3, padding=1),
            nn.PixelShuffle(2),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.GELU()
        )
    def forward(self, x):
        return self.up(x)

class TransUNet(nn.Module):
    def __init__(self, in_ch=1, out_ch=1, base_ch=64):
        super().__init__()
        self.enc1 = ConvBlock(in_ch, base_ch)
        self.enc2 = ConvBlock(base_ch, base_ch*2)
        self.enc3 = ConvBlock(base_ch*2, base_ch*4)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = TransformerBottleneck(base_ch*4, patch_size=16, num_layers=2, num_heads=4)
        self.up3  = UpCUP(base_ch*4, base_ch*2)
        self.dec3 = ConvBlock(base_ch*2 + base_ch*4, base_ch*2)
        self.up2  = UpCUP(base_ch*2, base_ch)
        self.dec2 = ConvBlock(base_ch + base_ch*2, base_ch)
        self.up1  = UpCUP(base_ch, base_ch//2)
        self.dec1 = ConvBlock(base_ch//2 + base_ch, base_ch//2)
        self.out_conv = nn.Conv2d(base_ch//2, out_ch, kernel_size=1)
    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        b  = self.bottleneck(self.pool(e3))
        d3 = self.dec3(torch.cat([self.up3(b),  e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
        return self.out_conv(d1)


# ==========================
# LOAD MODEL
# ==========================
model = TransUNet().to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()
print(f"Loaded model from {MODEL_PATH}")


# ==========================
# INFERENCE ON REAL IMAGES
# ==========================
files = [f for f in os.listdir(TEST_DIR) if f.lower().endswith((".png",".jpg",".jpeg"))]

for i, fname in enumerate(sorted(files)):
    path = os.path.join(TEST_DIR, fname)
    img = Image.open(path).convert("L").resize(IMAGE_SIZE)
    arr = np.array(img, dtype=np.float32) / 255.0
    inp = torch.tensor(arr).unsqueeze(0).unsqueeze(0).to(DEVICE)  # (1,1,H,W)

    with torch.no_grad():
        pred = model(inp).cpu().squeeze().numpy()

    pred_img = (np.clip(pred, 0, 1) * 255).astype(np.uint8)
    out_path = os.path.join(RESULTS_DIR, f"pred_{i+1}.png")
    Image.fromarray(pred_img).save(out_path)
    print(f"Saved: {out_path}")

print(f"\n✅ All predictions saved to {RESULTS_DIR}")


Loaded model from C:\Preet\png_images_400M_670M\fk_transunet++.pth
Saved: C:\Preet\Real GPR data\mendaley\Utilities_predictions_fk\pred_1.png
Saved: C:\Preet\Real GPR data\mendaley\Utilities_predictions_fk\pred_2.png
Saved: C:\Preet\Real GPR data\mendaley\Utilities_predictions_fk\pred_3.png
Saved: C:\Preet\Real GPR data\mendaley\Utilities_predictions_fk\pred_4.png
Saved: C:\Preet\Real GPR data\mendaley\Utilities_predictions_fk\pred_5.png
Saved: C:\Preet\Real GPR data\mendaley\Utilities_predictions_fk\pred_6.png
Saved: C:\Preet\Real GPR data\mendaley\Utilities_predictions_fk\pred_7.png
Saved: C:\Preet\Real GPR data\mendaley\Utilities_predictions_fk\pred_8.png
Saved: C:\Preet\Real GPR data\mendaley\Utilities_predictions_fk\pred_9.png
Saved: C:\Preet\Real GPR data\mendaley\Utilities_predictions_fk\pred_10.png
Saved: C:\Preet\Real GPR data\mendaley\Utilities_predictions_fk\pred_11.png
Saved: C:\Preet\Real GPR data\mendaley\Utilities_predictions_fk\pred_12.png
Saved: C:\Preet\Real GPR data\