In [2]:
import os
import cv2
import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

# ==========================
# CONFIG
# ==========================
DATASET_DIR = r"C:\Preet\9000_paired_bscans"  # *_l.png and *_h.png paired images
IMAGE_SIZE = (256, 256)
BATCH_SIZE = 16
EPOCHS = 50
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Running on:", DEVICE)

MODEL_PATH = os.path.join(DATASET_DIR, "best_transunet_huber+ssim.pth")
RESULTS_DIR = os.path.join(DATASET_DIR, "predictions_transunet_huber+ssim")
os.makedirs(RESULTS_DIR, exist_ok=True)

# ==========================
# DATASET
# ==========================
class GPRDataset(Dataset):
    def __init__(self, x_paths, y_paths):
        self.x_paths = x_paths
        self.y_paths = y_paths

    def __len__(self):
        return len(self.x_paths)

    def __getitem__(self, idx):
        x = np.array(Image.open(self.x_paths[idx]).resize(IMAGE_SIZE), dtype=np.float32) / 255.0
        y = np.array(Image.open(self.y_paths[idx]).resize(IMAGE_SIZE), dtype=np.float32) / 255.0
        x = torch.tensor(x).unsqueeze(0)  # (1,H,W)
        y = torch.tensor(y).unsqueeze(0)
        return x, y

def load_data(dataset_dir):
    low_paths, high_paths = [], []
    for file in os.listdir(dataset_dir):
        if file.endswith("_l.png"):
            low_path = os.path.join(dataset_dir, file)
            high_path = os.path.join(dataset_dir, file.replace("_l.png", "_h.png"))
            if os.path.exists(high_path):
                low_paths.append(low_path)
                high_paths.append(high_path)
    return low_paths, high_paths

all_x, all_y = load_data(DATASET_DIR)

# Split 70/15/15
train_x, temp_x, train_y, temp_y = train_test_split(all_x, all_y, test_size=0.30, random_state=42)
val_x, test_x, val_y, test_y = train_test_split(temp_x, temp_y, test_size=0.50, random_state=42)

print(f"Train: {len(train_x)}, Val: {len(val_x)}, Test: {len(test_x)}")

train_loader = DataLoader(GPRDataset(train_x, train_y), batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(GPRDataset(val_x, val_y), batch_size=BATCH_SIZE, shuffle=False)
test_loader  = DataLoader(GPRDataset(test_x, test_y), batch_size=1, shuffle=False)

# ==========================
# MODEL BLOCKS
# ==========================
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.conv(x)

# class AttentionGate(nn.Module):
#     def __init__(self, g_ch, x_ch, inter_ch):
#         super().__init__()
#         self.W_g = nn.Conv2d(g_ch, inter_ch, 1)
#         self.W_x = nn.Conv2d(x_ch, inter_ch, 1)
#         self.psi = nn.Conv2d(inter_ch, 1, 1)
#         self.relu = nn.ReLU(inplace=True)
#         self.sigmoid = nn.Sigmoid()

#     def forward(self, g, x):
#         g1 = self.W_g(g)
#         x1 = self.W_x(x)
#         psi = self.relu(g1 + x1)
#         psi = self.sigmoid(self.psi(psi))
#         return x * psi

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads=4, ff_dim=512, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + self.drop(attn_out))
        ff_out = self.ff(x)
        x = self.norm2(x + self.drop(ff_out))
        return x

class TransformerBottleneck(nn.Module):
    def __init__(self, in_ch, patch_size=16, num_layers=2, num_heads=4):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = in_ch

        self.flatten = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        self.project = nn.Linear(patch_size * patch_size * in_ch, in_ch)

        self.transformer = nn.Sequential(*[
            TransformerBlock(embed_dim=in_ch, num_heads=num_heads) for _ in range(num_layers)
        ])

        self.reconstruct = nn.Linear(in_ch, patch_size * patch_size * in_ch)
        self.fold = None  # initialized at runtime

    def forward(self, x):
        B, C, H, W = x.shape
        patches = self.flatten(x).transpose(1, 2)  # (B, N, patch_dim)
        patches = self.project(patches)

        patches = self.transformer(patches)

        patches = self.reconstruct(patches).transpose(1, 2)
        if self.fold is None:
            self.fold = nn.Fold(output_size=(H, W),
                                kernel_size=self.patch_size,
                                stride=self.patch_size)
        x_reconstructed = self.fold(patches)
        return x_reconstructed


class TransUNet(nn.Module):
    def __init__(self, in_ch=1, out_ch=1, base_ch=64):
        super().__init__()
        # ================= Encoder =================
        self.enc1 = ConvBlock(in_ch, base_ch)          # 1 -> 64
        self.enc2 = ConvBlock(base_ch, base_ch*2)     # 64 -> 128
        self.enc3 = ConvBlock(base_ch*2, base_ch*4)   # 128 -> 256
        self.pool = nn.MaxPool2d(2)

        # ================= Bottleneck =================
        self.bottleneck = TransformerBottleneck(base_ch*4, patch_size=16, num_layers=2, num_heads=4)

        # ================= Decoder =================
        # dec3 input: bottleneck + skip e3 = 256 + 256 = 512 channels
        self.dec3 = ConvBlock(base_ch*4 + base_ch*4, base_ch*2)  # 512 -> 128
        # dec2 input: dec3 output + skip e2 = 128 + 128 = 256 channels
        self.dec2 = ConvBlock(base_ch*2 + base_ch*2, base_ch)    # 256 -> 64

        # final conv
        self.out_conv = nn.Conv2d(base_ch, out_ch, kernel_size=1)  # 64 -> 1

    def forward(self, x):
        # ---------------- Encoder ----------------
        e1 = self.enc1(x)                  # [B,64,H,W]
        e2 = self.enc2(self.pool(e1))      # [B,128,H/2,W/2]
        e3 = self.enc3(self.pool(e2))      # [B,256,H/4,W/4]

        # ---------------- Bottleneck ----------------
        b = self.bottleneck(self.pool(e3)) # [B,256,H/8,W/8]

        # ---------------- Decoder ----------------
        d3 = F.interpolate(b, size=e3.shape[2:], mode='bilinear', align_corners=False)
        d3 = self.dec3(torch.cat([d3, e3], dim=1))  # [B,128,H/4,W/4]

        d2 = F.interpolate(d3, size=e2.shape[2:], mode='bilinear', align_corners=False)
        d2 = self.dec2(torch.cat([d2, e2], dim=1))  # [B,64,H/2,W/2]

        d1 = F.interpolate(d2, size=e1.shape[2:], mode='bilinear', align_corners=False)
        out = self.out_conv(d1)                        # [B,1,H,W]

        return out




# ==========================
# TRAINING
# ==========================
model = TransUNet().to(DEVICE)


from pytorch_msssim import ssim  # pip install pytorch-msssim

alpha = 0.8
beta  = 0.2

criterion = nn.HuberLoss(delta=1.0)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

best_val_loss = float("inf")

for epoch in range(EPOCHS):
    model.train()
    train_loss = 0.0
    for x, y in train_loader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        optimizer.zero_grad()
        preds = model(x)

        # Huber Loss
        huber_loss = criterion(preds, y)

        # SSIM Loss (1 - SSIM)
        ssim_loss = 1 - ssim(preds, y, data_range=1.0, size_average=True)  # assumes preds/y in [0,1]

        # Combined Loss
        loss = alpha * huber_loss + beta * ssim_loss
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * x.size(0)
    train_loss /= len(train_loader.dataset)

    # ---- Validation ----
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            preds = model(x)

            huber_loss = criterion(preds, y)
            ssim_loss = 1 - ssim(preds, y, data_range=1.0, size_average=True)
            loss = alpha * huber_loss + beta * ssim_loss

            val_loss += loss.item() * x.size(0)
    val_loss /= len(val_loader.dataset)

    print(f"Epoch [{epoch+1}/{EPOCHS}] Train Loss: {train_loss:.6f}, Val Loss: {val_loss:.6f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), MODEL_PATH)
        print(f"  ✅ Saved Best Model at Epoch {epoch+1}")


# ==========================
# INFERENCE
# ==========================
print("\nRunning inference on test set...")
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

for i, (x, y) in enumerate(test_loader):
    x = x.to(DEVICE)
    with torch.no_grad():
        pred = model(x).cpu().squeeze(0).squeeze(0).numpy()
    pred_img = (pred * 255.0).clip(0, 255).astype(np.uint8)
    Image.fromarray(pred_img).save(os.path.join(RESULTS_DIR, f"pred_{i+1}.png"))

print(f"Predictions saved in {RESULTS_DIR}")

Running on: cuda
Train: 6300, Val: 1350, Test: 1350
Epoch [1/50] Train Loss: 0.012190, Val Loss: 0.000855
  ✅ Saved Best Model at Epoch 1
Epoch [2/50] Train Loss: 0.000746, Val Loss: 0.000649
  ✅ Saved Best Model at Epoch 2
Epoch [3/50] Train Loss: 0.000632, Val Loss: 0.000582
  ✅ Saved Best Model at Epoch 3
Epoch [4/50] Train Loss: 0.000566, Val Loss: 0.000547
  ✅ Saved Best Model at Epoch 4
Epoch [5/50] Train Loss: 0.000522, Val Loss: 0.000488
  ✅ Saved Best Model at Epoch 5
Epoch [6/50] Train Loss: 0.000483, Val Loss: 0.000465
  ✅ Saved Best Model at Epoch 6
Epoch [7/50] Train Loss: 0.000460, Val Loss: 0.000457
  ✅ Saved Best Model at Epoch 7
Epoch [8/50] Train Loss: 0.000440, Val Loss: 0.000422
  ✅ Saved Best Model at Epoch 8
Epoch [9/50] Train Loss: 0.000425, Val Loss: 0.000446
Epoch [10/50] Train Loss: 0.000414, Val Loss: 0.000397
  ✅ Saved Best Model at Epoch 10
Epoch [11/50] Train Loss: 0.000398, Val Loss: 0.000395
  ✅ Saved Best Model at Epoch 11
Epoch [12/50] Train Loss: 0.00

In [3]:
import os
import cv2
import numpy as np
from skimage.metrics import structural_similarity as ssim
from math import log10
import shutil

# ========================== PATHS ==========================
RESULTS_DIR = r"C:\Preet\9000_paired_bscans\predictions_transunet_huber+ssim"  # predictions from TransUNet
GT_TEST_DIR = r"C:\Preet\9000_paired_bscans\ground_truth_test_transunet_huber+ssim"

# Copy ground truth test images to a dedicated folder
os.makedirs(GT_TEST_DIR, exist_ok=True)
for f in test_y:  # test_y comes from your train/val/test split
    shutil.copy(f, GT_TEST_DIR)

# ========================== FUNCTIONS ==========================
def calculate_psnr(img1, img2):
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * log10(255.0 / np.sqrt(mse))

def calculate_ssim(img1, img2):
    return ssim(img1, img2, data_range=255)

# ========================== MAIN EVALUATION ==========================
psnr_values, ssim_values = [], []

pred_files = sorted([f for f in os.listdir(RESULTS_DIR) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
gt_files   = sorted([f for f in os.listdir(GT_TEST_DIR) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])

num_pairs = min(len(pred_files), len(gt_files))
if num_pairs == 0:
    print("[Error] No matching image files found in both directories.")
else:
    if len(pred_files) != len(gt_files):
        print(f"[Warning] Different number of images: Predictions={len(pred_files)}, Ground Truth={len(gt_files)}")
        print(f"Evaluating only first {num_pairs} matched pairs.")

    for i in range(num_pairs):
        pred_path = os.path.join(RESULTS_DIR, pred_files[i])
        gt_path   = os.path.join(GT_TEST_DIR, gt_files[i])

        pred_img = cv2.imread(pred_path, cv2.IMREAD_GRAYSCALE)
        gt_img   = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)

        if pred_img is None or gt_img is None:
            print(f"[Error] Could not load: {pred_files[i]} or {gt_files[i]}")
            continue

        if pred_img.shape != gt_img.shape:
            pred_img = cv2.resize(pred_img, (gt_img.shape[1], gt_img.shape[0]))

        psnr_values.append(calculate_psnr(pred_img, gt_img))
        ssim_values.append(calculate_ssim(pred_img, gt_img))

    if psnr_values and ssim_values:
        print(f"\n---- Test Set Evaluation ----")
        print(f"SSIM: avg={np.mean(ssim_values):.4f}, min={np.min(ssim_values):.4f}, max={np.max(ssim_values):.4f}")
        print(f"PSNR: avg={np.mean(psnr_values):.2f} dB, min={np.min(psnr_values):.2f} dB, max={np.max(psnr_values):.2f} dB")



---- Test Set Evaluation ----
SSIM: avg=0.8929, min=0.7065, max=0.9982
PSNR: avg=35.29 dB, min=28.64 dB, max=48.84 dB


In [4]:
import os
import re
import numpy as np
from PIL import Image
import cv2
from math import log10
from skimage.metrics import structural_similarity as ssim

import torch
import torch.nn as nn
import torch.nn.functional as F

# =========================
# CONFIG
# =========================
FREQ_DIRS = [
    r"C:\Preet\validation dataset\png_images_650M_1083M",
    r"C:\Preet\validation dataset\png_images_700M_1167M",
    r"C:\Preet\validation dataset\png_images_800M_1333M",
    r"C:\Preet\validation dataset\png_images_850M_1416M",
    r"C:\Preet\validation dataset\png_images_900M_1500M",
]

MODEL_PATH = r"C:\Preet\9000_paired_bscans\best_transunet_huber+ssim.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMAGE_SIZE = (256, 256)

# =========================
# MODEL BLOCKS
# =========================
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True)
        )
    def forward(self, x): return self.conv(x)

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads=4, ff_dim=512, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.drop = nn.Dropout(dropout)
    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + self.drop(attn_out))
        ff_out = self.ff(x)
        x = self.norm2(x + self.drop(ff_out))
        return x

class TransformerBottleneck(nn.Module):
    def __init__(self, in_ch, patch_size=16, num_layers=2, num_heads=4):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = in_ch
        self.flatten = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        self.project = nn.Linear(patch_size*patch_size*in_ch, in_ch)
        self.transformer = nn.Sequential(*[TransformerBlock(embed_dim=in_ch, num_heads=num_heads) for _ in range(num_layers)])
        self.reconstruct = nn.Linear(in_ch, patch_size*patch_size*in_ch)
        self.fold = None
    def forward(self, x):
        B,C,H,W = x.shape
        patches = self.flatten(x).transpose(1,2)
        patches = self.project(patches)
        patches = self.transformer(patches)
        patches = self.reconstruct(patches).transpose(1,2)
        if self.fold is None:
            self.fold = nn.Fold(output_size=(H,W), kernel_size=self.patch_size, stride=self.patch_size)
        x_reconstructed = self.fold(patches)
        return x_reconstructed

class TransUNet(nn.Module):
    def __init__(self, in_ch=1, out_ch=1, base_ch=64):
        super().__init__()
        self.enc1 = ConvBlock(in_ch, base_ch)
        self.enc2 = ConvBlock(base_ch, base_ch*2)
        self.enc3 = ConvBlock(base_ch*2, base_ch*4)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = TransformerBottleneck(base_ch*4, patch_size=16, num_layers=2, num_heads=4)
        self.dec3 = ConvBlock(base_ch*4 + base_ch*4, base_ch*2)
        self.dec2 = ConvBlock(base_ch*2 + base_ch*2, base_ch)
        self.out_conv = nn.Conv2d(base_ch, out_ch, kernel_size=1)
    def forward(self, x):
        e1 = self.enc1(x); e2 = self.enc2(self.pool(e1)); e3 = self.enc3(self.pool(e2))
        b = self.bottleneck(self.pool(e3))
        d3 = F.interpolate(b, size=e3.shape[2:], mode='bilinear', align_corners=False)
        d3 = self.dec3(torch.cat([d3, e3], dim=1))
        d2 = F.interpolate(d3, size=e2.shape[2:], mode='bilinear', align_corners=False)
        d2 = self.dec2(torch.cat([d2, e2], dim=1))
        d1 = F.interpolate(d2, size=e1.shape[2:], mode='bilinear', align_corners=False)
        return self.out_conv(d1)

# =========================
# HELPERS
# =========================
def calculate_psnr(img1, img2):
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0: return float('inf')
    return 20 * log10(255.0 / np.sqrt(mse))

def calculate_ssim(img1, img2):
    return ssim(img1, img2, data_range=255)

def numerical_sort(files):
    return sorted(files, key=lambda f: int(re.search(r'\d+', f).group()))

# =========================
# LOAD MODEL
# =========================
model = TransUNet().to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

# =========================
# INFERENCE LOOP
# =========================
for freq_dir in FREQ_DIRS:
    print(f"\n---- Evaluating {os.path.basename(freq_dir)} ----")
    pred_dir = os.path.join(freq_dir, "predictions_transunet_huber+ssim")
    os.makedirs(pred_dir, exist_ok=True)

    low_files = numerical_sort([f for f in os.listdir(freq_dir) if f.endswith("_l.png")])
    high_files = numerical_sort([f for f in os.listdir(freq_dir) if f.endswith("_h.png")])

    psnr_values, ssim_values = [], []

    for i in range(len(low_files)):
        low_path = os.path.join(freq_dir, low_files[i])
        high_path = os.path.join(freq_dir, high_files[i])

        # Load LR
        lr_img = np.array(Image.open(low_path).resize(IMAGE_SIZE), dtype=np.float32)/255.0
        lr_tensor = torch.tensor(lr_img).unsqueeze(0).unsqueeze(0).to(DEVICE)

        # Predict
        with torch.no_grad():
            pred = model(lr_tensor).cpu().numpy()

        # remove batch/channel dims and scale to 0-255
        pred_img = np.squeeze(pred)
        pred_img = (pred_img*255.0).clip(0,255).astype(np.uint8)
        pred_img = cv2.normalize(pred_img, None, 0, 255, cv2.NORM_MINMAX)

        # save prediction
        pred_path = os.path.join(pred_dir, f"pred_{i+1}.png")
        Image.fromarray(pred_img).save(pred_path)

        # Load GT
        gt_img = cv2.imread(high_path, cv2.IMREAD_GRAYSCALE)
        if pred_img.shape != gt_img.shape:
            pred_img = cv2.resize(pred_img, (gt_img.shape[1], gt_img.shape[0]))

        # Metrics
        psnr_values.append(calculate_psnr(pred_img, gt_img))
        ssim_values.append(calculate_ssim(pred_img, gt_img))

    if psnr_values and ssim_values:
        print(f"SSIM: avg={np.mean(ssim_values):.4f}, min={np.min(ssim_values):.4f}, max={np.max(ssim_values):.4f}")
        print(f"PSNR: avg={np.mean(psnr_values):.2f} dB, min={np.min(psnr_values):.2f} dB, max={np.max(psnr_values):.2f} dB")
    else:
        print("[Error] No valid pairs processed.")


---- Evaluating png_images_650M_1083M ----
SSIM: avg=0.9559, min=0.9274, max=0.9672
PSNR: avg=35.56 dB, min=32.54 dB, max=36.83 dB

---- Evaluating png_images_700M_1167M ----
SSIM: avg=0.9811, min=0.7581, max=0.9856
PSNR: avg=36.55 dB, min=29.39 dB, max=37.53 dB

---- Evaluating png_images_800M_1333M ----
SSIM: avg=0.9906, min=0.9728, max=0.9941
PSNR: avg=38.89 dB, min=34.09 dB, max=40.08 dB

---- Evaluating png_images_850M_1416M ----
SSIM: avg=0.9650, min=0.9342, max=0.9752
PSNR: avg=34.06 dB, min=31.07 dB, max=35.52 dB

---- Evaluating png_images_900M_1500M ----
SSIM: avg=0.9344, min=0.8776, max=0.9532
PSNR: avg=33.95 dB, min=30.09 dB, max=35.40 dB


In [5]:
import os
import re
import numpy as np
from PIL import Image
import cv2
from math import log10
from skimage.metrics import structural_similarity as ssim

import torch
import torch.nn as nn
import torch.nn.functional as F

# =========================
# CONFIG
# =========================
DATASET_DIR = r"C:\Preet\validation dataset\png_images_650M_1083M_1800M"
MODEL_PATH = r"C:\Preet\9000_paired_bscans\best_transunet_huber+ssim.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMAGE_SIZE = (256, 256)

# =========================
# MODEL BLOCKS
# =========================
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True)
        )
    def forward(self, x): return self.conv(x)

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads=4, ff_dim=512, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.drop = nn.Dropout(dropout)
    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + self.drop(attn_out))
        ff_out = self.ff(x)
        x = self.norm2(x + self.drop(ff_out))
        return x

class TransformerBottleneck(nn.Module):
    def __init__(self, in_ch, patch_size=16, num_layers=2, num_heads=4):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = in_ch
        self.flatten = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        self.project = nn.Linear(patch_size*patch_size*in_ch, in_ch)
        self.transformer = nn.Sequential(*[TransformerBlock(embed_dim=in_ch, num_heads=num_heads) for _ in range(num_layers)])
        self.reconstruct = nn.Linear(in_ch, patch_size*patch_size*in_ch)
        self.fold = None
    def forward(self, x):
        B,C,H,W = x.shape
        patches = self.flatten(x).transpose(1,2)
        patches = self.project(patches)
        patches = self.transformer(patches)
        patches = self.reconstruct(patches).transpose(1,2)
        if self.fold is None:
            self.fold = nn.Fold(output_size=(H,W), kernel_size=self.patch_size, stride=self.patch_size)
        x_reconstructed = self.fold(patches)
        return x_reconstructed

class TransUNet(nn.Module):
    def __init__(self, in_ch=1, out_ch=1, base_ch=64):
        super().__init__()
        self.enc1 = ConvBlock(in_ch, base_ch)
        self.enc2 = ConvBlock(base_ch, base_ch*2)
        self.enc3 = ConvBlock(base_ch*2, base_ch*4)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = TransformerBottleneck(base_ch*4, patch_size=16, num_layers=2, num_heads=4)
        self.dec3 = ConvBlock(base_ch*4 + base_ch*4, base_ch*2)
        self.dec2 = ConvBlock(base_ch*2 + base_ch*2, base_ch)
        self.out_conv = nn.Conv2d(base_ch, out_ch, kernel_size=1)
    def forward(self, x):
        e1 = self.enc1(x); e2 = self.enc2(self.pool(e1)); e3 = self.enc3(self.pool(e2))
        b = self.bottleneck(self.pool(e3))
        d3 = F.interpolate(b, size=e3.shape[2:], mode='bilinear', align_corners=False)
        d3 = self.dec3(torch.cat([d3, e3], dim=1))
        d2 = F.interpolate(d3, size=e2.shape[2:], mode='bilinear', align_corners=False)
        d2 = self.dec2(torch.cat([d2, e2], dim=1))
        d1 = F.interpolate(d2, size=e1.shape[2:], mode='bilinear', align_corners=False)
        return self.out_conv(d1)

# =========================
# HELPERS
# =========================
def calculate_psnr(img1, img2):
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0: return float('inf')
    return 20 * log10(255.0 / np.sqrt(mse))

def calculate_ssim(img1, img2):
    return ssim(img1, img2, data_range=255)

def numerical_sort(files):
    return sorted(files, key=lambda f: int(re.search(r'\d+', f).group()))

def run_model(img, model):
    tensor = torch.tensor(img, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        pred = model(tensor).cpu().numpy()
    pred = np.squeeze(pred)
    pred = (pred * 255.0).clip(0, 255).astype(np.uint8)
    return cv2.normalize(pred, None, 0, 255, cv2.NORM_MINMAX)

# =========================
# LOAD MODEL
# =========================
model = TransUNet().to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

# =========================
# INFERENCE 2-STAGE
# =========================
files_650 = numerical_sort([f for f in os.listdir(DATASET_DIR) if f.endswith("650_bscan.png")])
files_1083 = numerical_sort([f for f in os.listdir(DATASET_DIR) if f.endswith("1083_bscan.png")])
files_1800 = numerical_sort([f for f in os.listdir(DATASET_DIR) if f.endswith("1800_bscan.png")])

psnr_stage1, ssim_stage1 = [], []
psnr_stage2, ssim_stage2 = [], []

pred_dir1 = os.path.join(DATASET_DIR, "predictions_stage1_transunet_huber+ssim")
pred_dir2 = os.path.join(DATASET_DIR, "predictions_stage2_transunet_huber+ssim")
os.makedirs(pred_dir1, exist_ok=True)
os.makedirs(pred_dir2, exist_ok=True)

for i in range(len(files_650)):
    # ---- Stage 1: 650 -> pred -> compare with 1083 ----
    lr_img = np.array(Image.open(os.path.join(DATASET_DIR, files_650[i])).resize(IMAGE_SIZE), dtype=np.float32) / 255.0
    gt_1083 = cv2.imread(os.path.join(DATASET_DIR, files_1083[i]), cv2.IMREAD_GRAYSCALE)

    pred_1083 = run_model(lr_img, model)
    Image.fromarray(pred_1083).save(os.path.join(pred_dir1, f"pred1_{i+1}.png"))

    if pred_1083.shape != gt_1083.shape:
        pred_1083 = cv2.resize(pred_1083, (gt_1083.shape[1], gt_1083.shape[0]))

    psnr_stage1.append(calculate_psnr(pred_1083, gt_1083))
    ssim_stage1.append(calculate_ssim(pred_1083, gt_1083))

    # ---- Stage 2: pred_1083 -> pred -> compare with 1800 ----
    gt_1800 = cv2.imread(os.path.join(DATASET_DIR, files_1800[i]), cv2.IMREAD_GRAYSCALE)
    pred_1083_resized = cv2.resize(pred_1083, IMAGE_SIZE).astype(np.float32)/255.0
    pred_1800 = run_model(pred_1083_resized, model)
    Image.fromarray(pred_1800).save(os.path.join(pred_dir2, f"pred2_{i+1}.png"))

    if pred_1800.shape != gt_1800.shape:
        pred_1800 = cv2.resize(pred_1800, (gt_1800.shape[1], gt_1800.shape[0]))

    psnr_stage2.append(calculate_psnr(pred_1800, gt_1800))
    ssim_stage2.append(calculate_ssim(pred_1800, gt_1800))

# =========================
# RESULTS
# =========================
print("\n---- Stage 1: 650 → pred → compare with 1083 ----")
print(f"SSIM: avg={np.mean(ssim_stage1):.4f}, min={np.min(ssim_stage1):.4f}, max={np.max(ssim_stage1):.4f}")
print(f"PSNR: avg={np.mean(psnr_stage1):.2f} dB, min={np.min(psnr_stage1):.2f} dB, max={np.max(psnr_stage1):.2f} dB")

print("\n---- Stage 2: pred_1083 → pred → compare with 1800 ----")
print(f"SSIM: avg={np.mean(ssim_stage2):.4f}, min={np.min(ssim_stage2):.4f}, max={np.max(ssim_stage2):.4f}")
print(f"PSNR: avg={np.mean(psnr_stage2):.2f} dB, min={np.min(psnr_stage2):.2f} dB, max={np.max(psnr_stage2):.2f} dB")



---- Stage 1: 650 → pred → compare with 1083 ----
SSIM: avg=0.9541, min=0.9275, max=0.9674
PSNR: avg=35.57 dB, min=32.63 dB, max=36.81 dB

---- Stage 2: pred_1083 → pred → compare with 1800 ----
SSIM: avg=0.9268, min=0.8927, max=0.9463
PSNR: avg=30.13 dB, min=28.69 dB, max=30.88 dB


In [1]:
import os
import re
import numpy as np
from PIL import Image
import cv2
from math import log10
from skimage.metrics import structural_similarity as ssim

import torch
import torch.nn as nn
import torch.nn.functional as F

# =========================
# CONFIG
# =========================
DATASET_DIR = r"C:\Preet\validation dataset\png_images_900M_1500M_2500M"
MODEL_PATH = r"C:\Preet\9000_paired_bscans\best_transunet_huber+ssim.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMAGE_SIZE = (256, 256)

# =========================
# MODEL BLOCKS
# =========================
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True)
        )
    def forward(self, x): return self.conv(x)

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads=4, ff_dim=512, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.drop = nn.Dropout(dropout)
    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + self.drop(attn_out))
        ff_out = self.ff(x)
        x = self.norm2(x + self.drop(ff_out))
        return x

class TransformerBottleneck(nn.Module):
    def __init__(self, in_ch, patch_size=16, num_layers=2, num_heads=4):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = in_ch
        self.flatten = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        self.project = nn.Linear(patch_size*patch_size*in_ch, in_ch)
        self.transformer = nn.Sequential(*[TransformerBlock(embed_dim=in_ch, num_heads=num_heads) for _ in range(num_layers)])
        self.reconstruct = nn.Linear(in_ch, patch_size*patch_size*in_ch)
        self.fold = None
    def forward(self, x):
        B,C,H,W = x.shape
        patches = self.flatten(x).transpose(1,2)
        patches = self.project(patches)
        patches = self.transformer(patches)
        patches = self.reconstruct(patches).transpose(1,2)
        if self.fold is None:
            self.fold = nn.Fold(output_size=(H,W), kernel_size=self.patch_size, stride=self.patch_size)
        x_reconstructed = self.fold(patches)
        return x_reconstructed

class TransUNet(nn.Module):
    def __init__(self, in_ch=1, out_ch=1, base_ch=64):
        super().__init__()
        self.enc1 = ConvBlock(in_ch, base_ch)
        self.enc2 = ConvBlock(base_ch, base_ch*2)
        self.enc3 = ConvBlock(base_ch*2, base_ch*4)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = TransformerBottleneck(base_ch*4, patch_size=16, num_layers=2, num_heads=4)
        self.dec3 = ConvBlock(base_ch*4 + base_ch*4, base_ch*2)
        self.dec2 = ConvBlock(base_ch*2 + base_ch*2, base_ch)
        self.out_conv = nn.Conv2d(base_ch, out_ch, kernel_size=1)
    def forward(self, x):
        e1 = self.enc1(x); e2 = self.enc2(self.pool(e1)); e3 = self.enc3(self.pool(e2))
        b = self.bottleneck(self.pool(e3))
        d3 = F.interpolate(b, size=e3.shape[2:], mode='bilinear', align_corners=False)
        d3 = self.dec3(torch.cat([d3, e3], dim=1))
        d2 = F.interpolate(d3, size=e2.shape[2:], mode='bilinear', align_corners=False)
        d2 = self.dec2(torch.cat([d2, e2], dim=1))
        d1 = F.interpolate(d2, size=e1.shape[2:], mode='bilinear', align_corners=False)
        return self.out_conv(d1)

# =========================
# HELPERS
# =========================
def calculate_psnr(img1, img2):
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0: return float('inf')
    return 20 * log10(255.0 / np.sqrt(mse))

def calculate_ssim(img1, img2):
    return ssim(img1, img2, data_range=255)

def numerical_sort(files):
    return sorted(files, key=lambda f: int(re.search(r'\d+', f).group()))

def run_model(img, model):
    tensor = torch.tensor(img, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        pred = model(tensor).cpu().numpy()
    pred = np.squeeze(pred)
    pred = (pred * 255.0).clip(0, 255).astype(np.uint8)
    return cv2.normalize(pred, None, 0, 255, cv2.NORM_MINMAX)

# =========================
# LOAD MODEL
# =========================
model = TransUNet().to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

# =========================
# INFERENCE 2-STAGE
# =========================
files_650 = numerical_sort([f for f in os.listdir(DATASET_DIR) if f.endswith("900_bscan.png")])
files_1083 = numerical_sort([f for f in os.listdir(DATASET_DIR) if f.endswith("1500_bscan.png")])
files_1800 = numerical_sort([f for f in os.listdir(DATASET_DIR) if f.endswith("2500_bscan.png")])

psnr_stage1, ssim_stage1 = [], []
psnr_stage2, ssim_stage2 = [], []

pred_dir1 = os.path.join(DATASET_DIR, "predictions_stage1_transunet_huber+ssim")
pred_dir2 = os.path.join(DATASET_DIR, "predictions_stage2_transunet_huber+ssim")
os.makedirs(pred_dir1, exist_ok=True)
os.makedirs(pred_dir2, exist_ok=True)

for i in range(len(files_650)):
    # ---- Stage 1: 650 -> pred -> compare with 1083 ----
    lr_img = np.array(Image.open(os.path.join(DATASET_DIR, files_650[i])).resize(IMAGE_SIZE), dtype=np.float32) / 255.0
    gt_1083 = cv2.imread(os.path.join(DATASET_DIR, files_1083[i]), cv2.IMREAD_GRAYSCALE)

    pred_1083 = run_model(lr_img, model)
    Image.fromarray(pred_1083).save(os.path.join(pred_dir1, f"pred1_{i+1}.png"))

    if pred_1083.shape != gt_1083.shape:
        pred_1083 = cv2.resize(pred_1083, (gt_1083.shape[1], gt_1083.shape[0]))

    psnr_stage1.append(calculate_psnr(pred_1083, gt_1083))
    ssim_stage1.append(calculate_ssim(pred_1083, gt_1083))

    # ---- Stage 2: pred_1083 -> pred -> compare with 1800 ----
    gt_1800 = cv2.imread(os.path.join(DATASET_DIR, files_1800[i]), cv2.IMREAD_GRAYSCALE)
    pred_1083_resized = cv2.resize(pred_1083, IMAGE_SIZE).astype(np.float32)/255.0
    pred_1800 = run_model(pred_1083_resized, model)
    Image.fromarray(pred_1800).save(os.path.join(pred_dir2, f"pred2_{i+1}.png"))

    if pred_1800.shape != gt_1800.shape:
        pred_1800 = cv2.resize(pred_1800, (gt_1800.shape[1], gt_1800.shape[0]))

    psnr_stage2.append(calculate_psnr(pred_1800, gt_1800))
    ssim_stage2.append(calculate_ssim(pred_1800, gt_1800))

# =========================
# RESULTS
# =========================
print("\n---- Stage 1: 900 → pred → compare with 1500 ----")
print(f"SSIM: avg={np.mean(ssim_stage1):.4f}, min={np.min(ssim_stage1):.4f}, max={np.max(ssim_stage1):.4f}")
print(f"PSNR: avg={np.mean(psnr_stage1):.2f} dB, min={np.min(psnr_stage1):.2f} dB, max={np.max(psnr_stage1):.2f} dB")

print("\n---- Stage 2: pred_1500 → pred → compare with 2500 ----")
print(f"SSIM: avg={np.mean(ssim_stage2):.4f}, min={np.min(ssim_stage2):.4f}, max={np.max(ssim_stage2):.4f}")
print(f"PSNR: avg={np.mean(psnr_stage2):.2f} dB, min={np.min(psnr_stage2):.2f} dB, max={np.max(psnr_stage2):.2f} dB")



---- Stage 1: 900 → pred → compare with 1500 ----
SSIM: avg=0.7192, min=0.5828, max=0.8514
PSNR: avg=30.30 dB, min=29.25 dB, max=31.04 dB

---- Stage 2: pred_1500 → pred → compare with 2500 ----
SSIM: avg=0.7269, min=0.6029, max=0.8504
PSNR: avg=27.91 dB, min=27.40 dB, max=28.67 dB


---------------

----------------

----------------

--------------

In [1]:
import os
import cv2
import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

# ==========================
# CONFIG
# ==========================
DATASET_DIR = r"C:\Preet\4000_paired_bscans"  # *_l.png and *_h.png paired images
IMAGE_SIZE = (256, 256)
BATCH_SIZE = 16
EPOCHS = 75
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Running on:", DEVICE)

MODEL_PATH = os.path.join(DATASET_DIR, "best_transunet_huber+ssim.pth")
RESULTS_DIR = os.path.join(DATASET_DIR, "predictions_transunet_huber+ssim")
os.makedirs(RESULTS_DIR, exist_ok=True)

# ==========================
# DATASET
# ==========================
class GPRDataset(Dataset):
    def __init__(self, x_paths, y_paths):
        self.x_paths = x_paths
        self.y_paths = y_paths

    def __len__(self):
        return len(self.x_paths)

    def __getitem__(self, idx):
        x = np.array(Image.open(self.x_paths[idx]).resize(IMAGE_SIZE), dtype=np.float32) / 255.0
        y = np.array(Image.open(self.y_paths[idx]).resize(IMAGE_SIZE), dtype=np.float32) / 255.0
        x = torch.tensor(x).unsqueeze(0)  # (1,H,W)
        y = torch.tensor(y).unsqueeze(0)
        return x, y

def load_data(dataset_dir):
    low_paths, high_paths = [], []
    for file in os.listdir(dataset_dir):
        if file.endswith("_l.png"):
            low_path = os.path.join(dataset_dir, file)
            high_path = os.path.join(dataset_dir, file.replace("_l.png", "_h.png"))
            if os.path.exists(high_path):
                low_paths.append(low_path)
                high_paths.append(high_path)
    return low_paths, high_paths

all_x, all_y = load_data(DATASET_DIR)

# Split 70/15/15
train_x, temp_x, train_y, temp_y = train_test_split(all_x, all_y, test_size=0.30, random_state=42)
val_x, test_x, val_y, test_y = train_test_split(temp_x, temp_y, test_size=0.50, random_state=42)

print(f"Train: {len(train_x)}, Val: {len(val_x)}, Test: {len(test_x)}")

train_loader = DataLoader(GPRDataset(train_x, train_y), batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(GPRDataset(val_x, val_y), batch_size=BATCH_SIZE, shuffle=False)
test_loader  = DataLoader(GPRDataset(test_x, test_y), batch_size=1, shuffle=False)

# ==========================
# MODEL BLOCKS
# ==========================
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.conv(x)

# class AttentionGate(nn.Module):
#     def __init__(self, g_ch, x_ch, inter_ch):
#         super().__init__()
#         self.W_g = nn.Conv2d(g_ch, inter_ch, 1)
#         self.W_x = nn.Conv2d(x_ch, inter_ch, 1)
#         self.psi = nn.Conv2d(inter_ch, 1, 1)
#         self.relu = nn.ReLU(inplace=True)
#         self.sigmoid = nn.Sigmoid()

#     def forward(self, g, x):
#         g1 = self.W_g(g)
#         x1 = self.W_x(x)
#         psi = self.relu(g1 + x1)
#         psi = self.sigmoid(self.psi(psi))
#         return x * psi

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads=4, ff_dim=512, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + self.drop(attn_out))
        ff_out = self.ff(x)
        x = self.norm2(x + self.drop(ff_out))
        return x

class TransformerBottleneck(nn.Module):
    def __init__(self, in_ch, patch_size=16, num_layers=2, num_heads=4):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = in_ch

        self.flatten = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        self.project = nn.Linear(patch_size * patch_size * in_ch, in_ch)

        self.transformer = nn.Sequential(*[
            TransformerBlock(embed_dim=in_ch, num_heads=num_heads) for _ in range(num_layers)
        ])

        self.reconstruct = nn.Linear(in_ch, patch_size * patch_size * in_ch)
        self.fold = None  # initialized at runtime

    def forward(self, x):
        B, C, H, W = x.shape
        patches = self.flatten(x).transpose(1, 2)  # (B, N, patch_dim)
        patches = self.project(patches)

        patches = self.transformer(patches)

        patches = self.reconstruct(patches).transpose(1, 2)
        if self.fold is None:
            self.fold = nn.Fold(output_size=(H, W),
                                kernel_size=self.patch_size,
                                stride=self.patch_size)
        x_reconstructed = self.fold(patches)
        return x_reconstructed


class TransUNet(nn.Module):
    def __init__(self, in_ch=1, out_ch=1, base_ch=64):
        super().__init__()
        # ================= Encoder =================
        self.enc1 = ConvBlock(in_ch, base_ch)          # 1 -> 64
        self.enc2 = ConvBlock(base_ch, base_ch*2)     # 64 -> 128
        self.enc3 = ConvBlock(base_ch*2, base_ch*4)   # 128 -> 256
        self.pool = nn.MaxPool2d(2)

        # ================= Bottleneck =================
        self.bottleneck = TransformerBottleneck(base_ch*4, patch_size=16, num_layers=2, num_heads=4)

        # ================= Decoder =================
        # dec3 input: bottleneck + skip e3 = 256 + 256 = 512 channels
        self.dec3 = ConvBlock(base_ch*4 + base_ch*4, base_ch*2)  # 512 -> 128
        # dec2 input: dec3 output + skip e2 = 128 + 128 = 256 channels
        self.dec2 = ConvBlock(base_ch*2 + base_ch*2, base_ch)    # 256 -> 64

        # final conv
        self.out_conv = nn.Conv2d(base_ch, out_ch, kernel_size=1)  # 64 -> 1

    def forward(self, x):
        # ---------------- Encoder ----------------
        e1 = self.enc1(x)                  # [B,64,H,W]
        e2 = self.enc2(self.pool(e1))      # [B,128,H/2,W/2]
        e3 = self.enc3(self.pool(e2))      # [B,256,H/4,W/4]

        # ---------------- Bottleneck ----------------
        b = self.bottleneck(self.pool(e3)) # [B,256,H/8,W/8]

        # ---------------- Decoder ----------------
        d3 = F.interpolate(b, size=e3.shape[2:], mode='bilinear', align_corners=False)
        d3 = self.dec3(torch.cat([d3, e3], dim=1))  # [B,128,H/4,W/4]

        d2 = F.interpolate(d3, size=e2.shape[2:], mode='bilinear', align_corners=False)
        d2 = self.dec2(torch.cat([d2, e2], dim=1))  # [B,64,H/2,W/2]

        d1 = F.interpolate(d2, size=e1.shape[2:], mode='bilinear', align_corners=False)
        out = self.out_conv(d1)                        # [B,1,H,W]

        return out




# ==========================
# TRAINING
# ==========================
model = TransUNet().to(DEVICE)


from pytorch_msssim import ssim  # pip install pytorch-msssim

alpha = 0.8
beta  = 0.2

criterion = nn.HuberLoss(delta=1.0)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

best_val_loss = float("inf")

for epoch in range(EPOCHS):
    model.train()
    train_loss = 0.0
    for x, y in train_loader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        optimizer.zero_grad()
        preds = model(x)

        # Huber Loss
        huber_loss = criterion(preds, y)

        # SSIM Loss (1 - SSIM)
        ssim_loss = 1 - ssim(preds, y, data_range=1.0, size_average=True)  # assumes preds/y in [0,1]

        # Combined Loss
        loss = alpha * huber_loss + beta * ssim_loss
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * x.size(0)
    train_loss /= len(train_loader.dataset)

    # ---- Validation ----
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            preds = model(x)

            huber_loss = criterion(preds, y)
            ssim_loss = 1 - ssim(preds, y, data_range=1.0, size_average=True)
            loss = alpha * huber_loss + beta * ssim_loss

            val_loss += loss.item() * x.size(0)
    val_loss /= len(val_loader.dataset)

    print(f"Epoch [{epoch+1}/{EPOCHS}] Train Loss: {train_loss:.6f}, Val Loss: {val_loss:.6f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), MODEL_PATH)
        print(f"  ✅ Saved Best Model at Epoch {epoch+1}")


# ==========================
# INFERENCE
# ==========================
print("\nRunning inference on test set...")
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

for i, (x, y) in enumerate(test_loader):
    x = x.to(DEVICE)
    with torch.no_grad():
        pred = model(x).cpu().squeeze(0).squeeze(0).numpy()
    pred_img = (pred * 255.0).clip(0, 255).astype(np.uint8)
    Image.fromarray(pred_img).save(os.path.join(RESULTS_DIR, f"pred_{i+1}.png"))

print(f"Predictions saved in {RESULTS_DIR}")

Running on: cuda
Train: 2800, Val: 600, Test: 600
Epoch [1/75] Train Loss: 0.031924, Val Loss: 0.001859
  ✅ Saved Best Model at Epoch 1
Epoch [2/75] Train Loss: 0.001310, Val Loss: 0.001274
  ✅ Saved Best Model at Epoch 2
Epoch [3/75] Train Loss: 0.000900, Val Loss: 0.000790
  ✅ Saved Best Model at Epoch 3
Epoch [4/75] Train Loss: 0.000745, Val Loss: 0.000791
Epoch [5/75] Train Loss: 0.000688, Val Loss: 0.000721
  ✅ Saved Best Model at Epoch 5
Epoch [6/75] Train Loss: 0.000633, Val Loss: 0.000631
  ✅ Saved Best Model at Epoch 6
Epoch [7/75] Train Loss: 0.000595, Val Loss: 0.000581
  ✅ Saved Best Model at Epoch 7
Epoch [8/75] Train Loss: 0.000572, Val Loss: 0.000549
  ✅ Saved Best Model at Epoch 8
Epoch [9/75] Train Loss: 0.000540, Val Loss: 0.000657
Epoch [10/75] Train Loss: 0.000545, Val Loss: 0.000525
  ✅ Saved Best Model at Epoch 10
Epoch [11/75] Train Loss: 0.000497, Val Loss: 0.000532
Epoch [12/75] Train Loss: 0.000501, Val Loss: 0.000515
  ✅ Saved Best Model at Epoch 12
Epoch [13

In [2]:
import os
import cv2
import numpy as np
from skimage.metrics import structural_similarity as ssim
from math import log10
import shutil

# ========================== PATHS ==========================
RESULTS_DIR = r"C:\Preet\4000_paired_bscans\predictions_transunet_huber+ssim"  # predictions from TransUNet
GT_TEST_DIR = r"C:\Preet\4000_paired_bscans\ground_truth_test_transunet_huber+ssim"

# Copy ground truth test images to a dedicated folder
os.makedirs(GT_TEST_DIR, exist_ok=True)
for f in test_y:  # test_y comes from your train/val/test split
    shutil.copy(f, GT_TEST_DIR)

# ========================== FUNCTIONS ==========================
def calculate_psnr(img1, img2):
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * log10(255.0 / np.sqrt(mse))

def calculate_ssim(img1, img2):
    return ssim(img1, img2, data_range=255)

# ========================== MAIN EVALUATION ==========================
psnr_values, ssim_values = [], []

pred_files = sorted([f for f in os.listdir(RESULTS_DIR) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
gt_files   = sorted([f for f in os.listdir(GT_TEST_DIR) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])

num_pairs = min(len(pred_files), len(gt_files))
if num_pairs == 0:
    print("[Error] No matching image files found in both directories.")
else:
    if len(pred_files) != len(gt_files):
        print(f"[Warning] Different number of images: Predictions={len(pred_files)}, Ground Truth={len(gt_files)}")
        print(f"Evaluating only first {num_pairs} matched pairs.")

    for i in range(num_pairs):
        pred_path = os.path.join(RESULTS_DIR, pred_files[i])
        gt_path   = os.path.join(GT_TEST_DIR, gt_files[i])

        pred_img = cv2.imread(pred_path, cv2.IMREAD_GRAYSCALE)
        gt_img   = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)

        if pred_img is None or gt_img is None:
            print(f"[Error] Could not load: {pred_files[i]} or {gt_files[i]}")
            continue

        if pred_img.shape != gt_img.shape:
            pred_img = cv2.resize(pred_img, (gt_img.shape[1], gt_img.shape[0]))

        psnr_values.append(calculate_psnr(pred_img, gt_img))
        ssim_values.append(calculate_ssim(pred_img, gt_img))

    if psnr_values and ssim_values:
        print(f"\n---- Test Set Evaluation ----")
        print(f"SSIM: avg={np.mean(ssim_values):.4f}, min={np.min(ssim_values):.4f}, max={np.max(ssim_values):.4f}")
        print(f"PSNR: avg={np.mean(psnr_values):.2f} dB, min={np.min(psnr_values):.2f} dB, max={np.max(psnr_values):.2f} dB")



---- Test Set Evaluation ----
SSIM: avg=0.8915, min=0.7125, max=0.9974
PSNR: avg=35.40 dB, min=28.93 dB, max=50.90 dB


In [3]:
import os
import re
import numpy as np
from PIL import Image
import cv2
from math import log10
from skimage.metrics import structural_similarity as ssim

import torch
import torch.nn as nn
import torch.nn.functional as F

# =========================
# CONFIG
# =========================
FREQ_DIRS = [
    r"C:\Preet\validation dataset_4000\png_images_650M_1083M",
    r"C:\Preet\validation dataset_4000\png_images_700M_1167M",
    r"C:\Preet\validation dataset_4000\png_images_800M_1333M",
    r"C:\Preet\validation dataset_4000\png_images_850M_1416M",
    r"C:\Preet\validation dataset_4000\png_images_900M_1500M",
]

MODEL_PATH = r"C:\Preet\4000_paired_bscans\best_transunet_huber+ssim.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMAGE_SIZE = (256, 256)

# =========================
# MODEL BLOCKS
# =========================
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True)
        )
    def forward(self, x): return self.conv(x)

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads=4, ff_dim=512, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.drop = nn.Dropout(dropout)
    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + self.drop(attn_out))
        ff_out = self.ff(x)
        x = self.norm2(x + self.drop(ff_out))
        return x

class TransformerBottleneck(nn.Module):
    def __init__(self, in_ch, patch_size=16, num_layers=2, num_heads=4):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = in_ch
        self.flatten = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        self.project = nn.Linear(patch_size*patch_size*in_ch, in_ch)
        self.transformer = nn.Sequential(*[TransformerBlock(embed_dim=in_ch, num_heads=num_heads) for _ in range(num_layers)])
        self.reconstruct = nn.Linear(in_ch, patch_size*patch_size*in_ch)
        self.fold = None
    def forward(self, x):
        B,C,H,W = x.shape
        patches = self.flatten(x).transpose(1,2)
        patches = self.project(patches)
        patches = self.transformer(patches)
        patches = self.reconstruct(patches).transpose(1,2)
        if self.fold is None:
            self.fold = nn.Fold(output_size=(H,W), kernel_size=self.patch_size, stride=self.patch_size)
        x_reconstructed = self.fold(patches)
        return x_reconstructed

class TransUNet(nn.Module):
    def __init__(self, in_ch=1, out_ch=1, base_ch=64):
        super().__init__()
        self.enc1 = ConvBlock(in_ch, base_ch)
        self.enc2 = ConvBlock(base_ch, base_ch*2)
        self.enc3 = ConvBlock(base_ch*2, base_ch*4)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = TransformerBottleneck(base_ch*4, patch_size=16, num_layers=2, num_heads=4)
        self.dec3 = ConvBlock(base_ch*4 + base_ch*4, base_ch*2)
        self.dec2 = ConvBlock(base_ch*2 + base_ch*2, base_ch)
        self.out_conv = nn.Conv2d(base_ch, out_ch, kernel_size=1)
    def forward(self, x):
        e1 = self.enc1(x); e2 = self.enc2(self.pool(e1)); e3 = self.enc3(self.pool(e2))
        b = self.bottleneck(self.pool(e3))
        d3 = F.interpolate(b, size=e3.shape[2:], mode='bilinear', align_corners=False)
        d3 = self.dec3(torch.cat([d3, e3], dim=1))
        d2 = F.interpolate(d3, size=e2.shape[2:], mode='bilinear', align_corners=False)
        d2 = self.dec2(torch.cat([d2, e2], dim=1))
        d1 = F.interpolate(d2, size=e1.shape[2:], mode='bilinear', align_corners=False)
        return self.out_conv(d1)

# =========================
# HELPERS
# =========================
def calculate_psnr(img1, img2):
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0: return float('inf')
    return 20 * log10(255.0 / np.sqrt(mse))

def calculate_ssim(img1, img2):
    return ssim(img1, img2, data_range=255)

def numerical_sort(files):
    return sorted(files, key=lambda f: int(re.search(r'\d+', f).group()))

# =========================
# LOAD MODEL
# =========================
model = TransUNet().to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

# =========================
# INFERENCE LOOP
# =========================
for freq_dir in FREQ_DIRS:
    print(f"\n---- Evaluating {os.path.basename(freq_dir)} ----")
    pred_dir = os.path.join(freq_dir, "predictions_transunet_huber+ssim")
    os.makedirs(pred_dir, exist_ok=True)

    low_files = numerical_sort([f for f in os.listdir(freq_dir) if f.endswith("_l.png")])
    high_files = numerical_sort([f for f in os.listdir(freq_dir) if f.endswith("_h.png")])

    psnr_values, ssim_values = [], []

    for i in range(len(low_files)):
        low_path = os.path.join(freq_dir, low_files[i])
        high_path = os.path.join(freq_dir, high_files[i])

        # Load LR
        lr_img = np.array(Image.open(low_path).resize(IMAGE_SIZE), dtype=np.float32)/255.0
        lr_tensor = torch.tensor(lr_img).unsqueeze(0).unsqueeze(0).to(DEVICE)

        # Predict
        with torch.no_grad():
            pred = model(lr_tensor).cpu().numpy()

        # remove batch/channel dims and scale to 0-255
        pred_img = np.squeeze(pred)
        pred_img = (pred_img*255.0).clip(0,255).astype(np.uint8)
        pred_img = cv2.normalize(pred_img, None, 0, 255, cv2.NORM_MINMAX)

        # save prediction
        pred_path = os.path.join(pred_dir, f"pred_{i+1}.png")
        Image.fromarray(pred_img).save(pred_path)

        # Load GT
        gt_img = cv2.imread(high_path, cv2.IMREAD_GRAYSCALE)
        if pred_img.shape != gt_img.shape:
            pred_img = cv2.resize(pred_img, (gt_img.shape[1], gt_img.shape[0]))

        # Metrics
        psnr_values.append(calculate_psnr(pred_img, gt_img))
        ssim_values.append(calculate_ssim(pred_img, gt_img))

    if psnr_values and ssim_values:
        print(f"SSIM: avg={np.mean(ssim_values):.4f}, min={np.min(ssim_values):.4f}, max={np.max(ssim_values):.4f}")
        print(f"PSNR: avg={np.mean(psnr_values):.2f} dB, min={np.min(psnr_values):.2f} dB, max={np.max(psnr_values):.2f} dB")
    else:
        print("[Error] No valid pairs processed.")


---- Evaluating png_images_650M_1083M ----
SSIM: avg=0.9572, min=0.9274, max=0.9685
PSNR: avg=35.19 dB, min=32.11 dB, max=36.48 dB

---- Evaluating png_images_700M_1167M ----
SSIM: avg=0.9823, min=0.7541, max=0.9875
PSNR: avg=36.57 dB, min=29.17 dB, max=38.28 dB

---- Evaluating png_images_800M_1333M ----
SSIM: avg=0.9893, min=0.9718, max=0.9930
PSNR: avg=37.24 dB, min=33.45 dB, max=38.14 dB

---- Evaluating png_images_850M_1416M ----
SSIM: avg=0.9649, min=0.9314, max=0.9759
PSNR: avg=33.40 dB, min=31.36 dB, max=35.17 dB

---- Evaluating png_images_900M_1500M ----
SSIM: avg=0.9378, min=0.8767, max=0.9562
PSNR: avg=33.12 dB, min=30.56 dB, max=33.89 dB


In [4]:
import os
import re
import numpy as np
from PIL import Image
import cv2
from math import log10
from skimage.metrics import structural_similarity as ssim

import torch
import torch.nn as nn
import torch.nn.functional as F

# =========================
# CONFIG
# =========================
DATASET_DIR = r"C:\Preet\validation dataset_4000\png_images_650M_1083M_1800M"
MODEL_PATH = r"C:\Preet\4000_paired_bscans\best_transunet_huber+ssim.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMAGE_SIZE = (256, 256)

# =========================
# MODEL BLOCKS
# =========================
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True)
        )
    def forward(self, x): return self.conv(x)

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads=4, ff_dim=512, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.drop = nn.Dropout(dropout)
    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + self.drop(attn_out))
        ff_out = self.ff(x)
        x = self.norm2(x + self.drop(ff_out))
        return x

class TransformerBottleneck(nn.Module):
    def __init__(self, in_ch, patch_size=16, num_layers=2, num_heads=4):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = in_ch
        self.flatten = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        self.project = nn.Linear(patch_size*patch_size*in_ch, in_ch)
        self.transformer = nn.Sequential(*[TransformerBlock(embed_dim=in_ch, num_heads=num_heads) for _ in range(num_layers)])
        self.reconstruct = nn.Linear(in_ch, patch_size*patch_size*in_ch)
        self.fold = None
    def forward(self, x):
        B,C,H,W = x.shape
        patches = self.flatten(x).transpose(1,2)
        patches = self.project(patches)
        patches = self.transformer(patches)
        patches = self.reconstruct(patches).transpose(1,2)
        if self.fold is None:
            self.fold = nn.Fold(output_size=(H,W), kernel_size=self.patch_size, stride=self.patch_size)
        x_reconstructed = self.fold(patches)
        return x_reconstructed

class TransUNet(nn.Module):
    def __init__(self, in_ch=1, out_ch=1, base_ch=64):
        super().__init__()
        self.enc1 = ConvBlock(in_ch, base_ch)
        self.enc2 = ConvBlock(base_ch, base_ch*2)
        self.enc3 = ConvBlock(base_ch*2, base_ch*4)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = TransformerBottleneck(base_ch*4, patch_size=16, num_layers=2, num_heads=4)
        self.dec3 = ConvBlock(base_ch*4 + base_ch*4, base_ch*2)
        self.dec2 = ConvBlock(base_ch*2 + base_ch*2, base_ch)
        self.out_conv = nn.Conv2d(base_ch, out_ch, kernel_size=1)
    def forward(self, x):
        e1 = self.enc1(x); e2 = self.enc2(self.pool(e1)); e3 = self.enc3(self.pool(e2))
        b = self.bottleneck(self.pool(e3))
        d3 = F.interpolate(b, size=e3.shape[2:], mode='bilinear', align_corners=False)
        d3 = self.dec3(torch.cat([d3, e3], dim=1))
        d2 = F.interpolate(d3, size=e2.shape[2:], mode='bilinear', align_corners=False)
        d2 = self.dec2(torch.cat([d2, e2], dim=1))
        d1 = F.interpolate(d2, size=e1.shape[2:], mode='bilinear', align_corners=False)
        return self.out_conv(d1)

# =========================
# HELPERS
# =========================
def calculate_psnr(img1, img2):
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0: return float('inf')
    return 20 * log10(255.0 / np.sqrt(mse))

def calculate_ssim(img1, img2):
    return ssim(img1, img2, data_range=255)

def numerical_sort(files):
    return sorted(files, key=lambda f: int(re.search(r'\d+', f).group()))

def run_model(img, model):
    tensor = torch.tensor(img, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        pred = model(tensor).cpu().numpy()
    pred = np.squeeze(pred)
    pred = (pred * 255.0).clip(0, 255).astype(np.uint8)
    return cv2.normalize(pred, None, 0, 255, cv2.NORM_MINMAX)

# =========================
# LOAD MODEL
# =========================
model = TransUNet().to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

# =========================
# INFERENCE 2-STAGE
# =========================
files_650 = numerical_sort([f for f in os.listdir(DATASET_DIR) if f.endswith("650_bscan.png")])
files_1083 = numerical_sort([f for f in os.listdir(DATASET_DIR) if f.endswith("1083_bscan.png")])
files_1800 = numerical_sort([f for f in os.listdir(DATASET_DIR) if f.endswith("1800_bscan.png")])

psnr_stage1, ssim_stage1 = [], []
psnr_stage2, ssim_stage2 = [], []

pred_dir1 = os.path.join(DATASET_DIR, "predictions_stage1_transunet_huber+ssim")
pred_dir2 = os.path.join(DATASET_DIR, "predictions_stage2_transunet_huber+ssim")
os.makedirs(pred_dir1, exist_ok=True)
os.makedirs(pred_dir2, exist_ok=True)

for i in range(len(files_650)):
    # ---- Stage 1: 650 -> pred -> compare with 1083 ----
    lr_img = np.array(Image.open(os.path.join(DATASET_DIR, files_650[i])).resize(IMAGE_SIZE), dtype=np.float32) / 255.0
    gt_1083 = cv2.imread(os.path.join(DATASET_DIR, files_1083[i]), cv2.IMREAD_GRAYSCALE)

    pred_1083 = run_model(lr_img, model)
    Image.fromarray(pred_1083).save(os.path.join(pred_dir1, f"pred1_{i+1}.png"))

    if pred_1083.shape != gt_1083.shape:
        pred_1083 = cv2.resize(pred_1083, (gt_1083.shape[1], gt_1083.shape[0]))

    psnr_stage1.append(calculate_psnr(pred_1083, gt_1083))
    ssim_stage1.append(calculate_ssim(pred_1083, gt_1083))

    # ---- Stage 2: pred_1083 -> pred -> compare with 1800 ----
    gt_1800 = cv2.imread(os.path.join(DATASET_DIR, files_1800[i]), cv2.IMREAD_GRAYSCALE)
    pred_1083_resized = cv2.resize(pred_1083, IMAGE_SIZE).astype(np.float32)/255.0
    pred_1800 = run_model(pred_1083_resized, model)
    Image.fromarray(pred_1800).save(os.path.join(pred_dir2, f"pred2_{i+1}.png"))

    if pred_1800.shape != gt_1800.shape:
        pred_1800 = cv2.resize(pred_1800, (gt_1800.shape[1], gt_1800.shape[0]))

    psnr_stage2.append(calculate_psnr(pred_1800, gt_1800))
    ssim_stage2.append(calculate_ssim(pred_1800, gt_1800))

# =========================
# RESULTS
# =========================
print("\n---- Stage 1: 650 → pred → compare with 1083 ----")
print(f"SSIM: avg={np.mean(ssim_stage1):.4f}, min={np.min(ssim_stage1):.4f}, max={np.max(ssim_stage1):.4f}")
print(f"PSNR: avg={np.mean(psnr_stage1):.2f} dB, min={np.min(psnr_stage1):.2f} dB, max={np.max(psnr_stage1):.2f} dB")

print("\n---- Stage 2: pred_1083 → pred → compare with 1800 ----")
print(f"SSIM: avg={np.mean(ssim_stage2):.4f}, min={np.min(ssim_stage2):.4f}, max={np.max(ssim_stage2):.4f}")
print(f"PSNR: avg={np.mean(psnr_stage2):.2f} dB, min={np.min(psnr_stage2):.2f} dB, max={np.max(psnr_stage2):.2f} dB")



---- Stage 1: 650 → pred → compare with 1083 ----
SSIM: avg=0.9552, min=0.9251, max=0.9680
PSNR: avg=35.14 dB, min=31.82 dB, max=36.37 dB

---- Stage 2: pred_1083 → pred → compare with 1800 ----
SSIM: avg=0.9244, min=0.8801, max=0.9424
PSNR: avg=28.53 dB, min=27.84 dB, max=31.16 dB


In [5]:
import os
import re
import numpy as np
from PIL import Image
import cv2
from math import log10
from skimage.metrics import structural_similarity as ssim

import torch
import torch.nn as nn
import torch.nn.functional as F

# =========================
# CONFIG
# =========================
DATASET_DIR = r"C:\Preet\validation dataset_4000\png_images_900M_1500M_2500M"
MODEL_PATH = r"C:\Preet\4000_paired_bscans\best_transunet_huber+ssim.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMAGE_SIZE = (256, 256)

# =========================
# MODEL BLOCKS
# =========================
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True)
        )
    def forward(self, x): return self.conv(x)

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads=4, ff_dim=512, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.drop = nn.Dropout(dropout)
    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + self.drop(attn_out))
        ff_out = self.ff(x)
        x = self.norm2(x + self.drop(ff_out))
        return x

class TransformerBottleneck(nn.Module):
    def __init__(self, in_ch, patch_size=16, num_layers=2, num_heads=4):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = in_ch
        self.flatten = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        self.project = nn.Linear(patch_size*patch_size*in_ch, in_ch)
        self.transformer = nn.Sequential(*[TransformerBlock(embed_dim=in_ch, num_heads=num_heads) for _ in range(num_layers)])
        self.reconstruct = nn.Linear(in_ch, patch_size*patch_size*in_ch)
        self.fold = None
    def forward(self, x):
        B,C,H,W = x.shape
        patches = self.flatten(x).transpose(1,2)
        patches = self.project(patches)
        patches = self.transformer(patches)
        patches = self.reconstruct(patches).transpose(1,2)
        if self.fold is None:
            self.fold = nn.Fold(output_size=(H,W), kernel_size=self.patch_size, stride=self.patch_size)
        x_reconstructed = self.fold(patches)
        return x_reconstructed

class TransUNet(nn.Module):
    def __init__(self, in_ch=1, out_ch=1, base_ch=64):
        super().__init__()
        self.enc1 = ConvBlock(in_ch, base_ch)
        self.enc2 = ConvBlock(base_ch, base_ch*2)
        self.enc3 = ConvBlock(base_ch*2, base_ch*4)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = TransformerBottleneck(base_ch*4, patch_size=16, num_layers=2, num_heads=4)
        self.dec3 = ConvBlock(base_ch*4 + base_ch*4, base_ch*2)
        self.dec2 = ConvBlock(base_ch*2 + base_ch*2, base_ch)
        self.out_conv = nn.Conv2d(base_ch, out_ch, kernel_size=1)
    def forward(self, x):
        e1 = self.enc1(x); e2 = self.enc2(self.pool(e1)); e3 = self.enc3(self.pool(e2))
        b = self.bottleneck(self.pool(e3))
        d3 = F.interpolate(b, size=e3.shape[2:], mode='bilinear', align_corners=False)
        d3 = self.dec3(torch.cat([d3, e3], dim=1))
        d2 = F.interpolate(d3, size=e2.shape[2:], mode='bilinear', align_corners=False)
        d2 = self.dec2(torch.cat([d2, e2], dim=1))
        d1 = F.interpolate(d2, size=e1.shape[2:], mode='bilinear', align_corners=False)
        return self.out_conv(d1)

# =========================
# HELPERS
# =========================
def calculate_psnr(img1, img2):
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0: return float('inf')
    return 20 * log10(255.0 / np.sqrt(mse))

def calculate_ssim(img1, img2):
    return ssim(img1, img2, data_range=255)

def numerical_sort(files):
    return sorted(files, key=lambda f: int(re.search(r'\d+', f).group()))

def run_model(img, model):
    tensor = torch.tensor(img, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        pred = model(tensor).cpu().numpy()
    pred = np.squeeze(pred)
    pred = (pred * 255.0).clip(0, 255).astype(np.uint8)
    return cv2.normalize(pred, None, 0, 255, cv2.NORM_MINMAX)

# =========================
# LOAD MODEL
# =========================
model = TransUNet().to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

# =========================
# INFERENCE 2-STAGE
# =========================
files_650 = numerical_sort([f for f in os.listdir(DATASET_DIR) if f.endswith("900_bscan.png")])
files_1083 = numerical_sort([f for f in os.listdir(DATASET_DIR) if f.endswith("1500_bscan.png")])
files_1800 = numerical_sort([f for f in os.listdir(DATASET_DIR) if f.endswith("2500_bscan.png")])

psnr_stage1, ssim_stage1 = [], []
psnr_stage2, ssim_stage2 = [], []

pred_dir1 = os.path.join(DATASET_DIR, "predictions_stage1_transunet_huber+ssim")
pred_dir2 = os.path.join(DATASET_DIR, "predictions_stage2_transunet_huber+ssim")
os.makedirs(pred_dir1, exist_ok=True)
os.makedirs(pred_dir2, exist_ok=True)

for i in range(len(files_650)):
    # ---- Stage 1: 650 -> pred -> compare with 1083 ----
    lr_img = np.array(Image.open(os.path.join(DATASET_DIR, files_650[i])).resize(IMAGE_SIZE), dtype=np.float32) / 255.0
    gt_1083 = cv2.imread(os.path.join(DATASET_DIR, files_1083[i]), cv2.IMREAD_GRAYSCALE)

    pred_1083 = run_model(lr_img, model)
    Image.fromarray(pred_1083).save(os.path.join(pred_dir1, f"pred1_{i+1}.png"))

    if pred_1083.shape != gt_1083.shape:
        pred_1083 = cv2.resize(pred_1083, (gt_1083.shape[1], gt_1083.shape[0]))

    psnr_stage1.append(calculate_psnr(pred_1083, gt_1083))
    ssim_stage1.append(calculate_ssim(pred_1083, gt_1083))

    # ---- Stage 2: pred_1083 -> pred -> compare with 1800 ----
    gt_1800 = cv2.imread(os.path.join(DATASET_DIR, files_1800[i]), cv2.IMREAD_GRAYSCALE)
    pred_1083_resized = cv2.resize(pred_1083, IMAGE_SIZE).astype(np.float32)/255.0
    pred_1800 = run_model(pred_1083_resized, model)
    Image.fromarray(pred_1800).save(os.path.join(pred_dir2, f"pred2_{i+1}.png"))

    if pred_1800.shape != gt_1800.shape:
        pred_1800 = cv2.resize(pred_1800, (gt_1800.shape[1], gt_1800.shape[0]))

    psnr_stage2.append(calculate_psnr(pred_1800, gt_1800))
    ssim_stage2.append(calculate_ssim(pred_1800, gt_1800))

# =========================
# RESULTS
# =========================
print("\n---- Stage 1: 900 → pred → compare with 1500 ----")
print(f"SSIM: avg={np.mean(ssim_stage1):.4f}, min={np.min(ssim_stage1):.4f}, max={np.max(ssim_stage1):.4f}")
print(f"PSNR: avg={np.mean(psnr_stage1):.2f} dB, min={np.min(psnr_stage1):.2f} dB, max={np.max(psnr_stage1):.2f} dB")

print("\n---- Stage 2: pred_1500 → pred → compare with 2500 ----")
print(f"SSIM: avg={np.mean(ssim_stage2):.4f}, min={np.min(ssim_stage2):.4f}, max={np.max(ssim_stage2):.4f}")
print(f"PSNR: avg={np.mean(psnr_stage2):.2f} dB, min={np.min(psnr_stage2):.2f} dB, max={np.max(psnr_stage2):.2f} dB")


---- Stage 1: 900 → pred → compare with 1500 ----
SSIM: avg=0.7291, min=0.5953, max=0.8614
PSNR: avg=29.81 dB, min=28.60 dB, max=30.28 dB

---- Stage 2: pred_1500 → pred → compare with 2500 ----
SSIM: avg=0.7513, min=0.6396, max=0.8559
PSNR: avg=27.60 dB, min=27.30 dB, max=28.15 dB


--------------------

-------------------

--------------------

-----------------------

clean dataset

In [1]:
import os
import cv2
import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

# ==========================
# CONFIG
# ==========================
DATASET_DIR = r"C:\Preet\clean_paired_bscans"  # *_l.png and *_h.png paired images
IMAGE_SIZE = (256, 256)
BATCH_SIZE = 8
EPOCHS = 75
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Running on:", DEVICE)

MODEL_PATH = os.path.join(DATASET_DIR, "best_transunet_huber+ssim.pth")
RESULTS_DIR = os.path.join(DATASET_DIR, "predictions_transunet_huber+ssim")
os.makedirs(RESULTS_DIR, exist_ok=True)

# ==========================
# DATASET
# ==========================
class GPRDataset(Dataset):
    def __init__(self, x_paths, y_paths):
        self.x_paths = x_paths
        self.y_paths = y_paths

    def __len__(self):
        return len(self.x_paths)

    def __getitem__(self, idx):
        x = np.array(Image.open(self.x_paths[idx]).resize(IMAGE_SIZE), dtype=np.float32) / 255.0
        y = np.array(Image.open(self.y_paths[idx]).resize(IMAGE_SIZE), dtype=np.float32) / 255.0
        x = torch.tensor(x).unsqueeze(0)  # (1,H,W)
        y = torch.tensor(y).unsqueeze(0)
        return x, y

def load_data(dataset_dir):
    low_paths, high_paths = [], []
    for file in os.listdir(dataset_dir):
        if file.endswith("_l.png"):
            low_path = os.path.join(dataset_dir, file)
            high_path = os.path.join(dataset_dir, file.replace("_l.png", "_h.png"))
            if os.path.exists(high_path):
                low_paths.append(low_path)
                high_paths.append(high_path)
    return low_paths, high_paths

all_x, all_y = load_data(DATASET_DIR)

# Split 70/15/15
train_x, temp_x, train_y, temp_y = train_test_split(all_x, all_y, test_size=0.30, random_state=42)
val_x, test_x, val_y, test_y = train_test_split(temp_x, temp_y, test_size=0.50, random_state=42)

print(f"Train: {len(train_x)}, Val: {len(val_x)}, Test: {len(test_x)}")

train_loader = DataLoader(GPRDataset(train_x, train_y), batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(GPRDataset(val_x, val_y), batch_size=BATCH_SIZE, shuffle=False)
test_loader  = DataLoader(GPRDataset(test_x, test_y), batch_size=1, shuffle=False)

# ==========================
# MODEL BLOCKS
# ==========================
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.conv(x)

# class AttentionGate(nn.Module):
#     def __init__(self, g_ch, x_ch, inter_ch):
#         super().__init__()
#         self.W_g = nn.Conv2d(g_ch, inter_ch, 1)
#         self.W_x = nn.Conv2d(x_ch, inter_ch, 1)
#         self.psi = nn.Conv2d(inter_ch, 1, 1)
#         self.relu = nn.ReLU(inplace=True)
#         self.sigmoid = nn.Sigmoid()

#     def forward(self, g, x):
#         g1 = self.W_g(g)
#         x1 = self.W_x(x)
#         psi = self.relu(g1 + x1)
#         psi = self.sigmoid(self.psi(psi))
#         return x * psi

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads=4, ff_dim=512, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + self.drop(attn_out))
        ff_out = self.ff(x)
        x = self.norm2(x + self.drop(ff_out))
        return x

class TransformerBottleneck(nn.Module):
    def __init__(self, in_ch, patch_size=16, num_layers=2, num_heads=4):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = in_ch

        self.flatten = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        self.project = nn.Linear(patch_size * patch_size * in_ch, in_ch)

        self.transformer = nn.Sequential(*[
            TransformerBlock(embed_dim=in_ch, num_heads=num_heads) for _ in range(num_layers)
        ])

        self.reconstruct = nn.Linear(in_ch, patch_size * patch_size * in_ch)
        self.fold = None  # initialized at runtime

    def forward(self, x):
        B, C, H, W = x.shape
        patches = self.flatten(x).transpose(1, 2)  # (B, N, patch_dim)
        patches = self.project(patches)

        patches = self.transformer(patches)

        patches = self.reconstruct(patches).transpose(1, 2)
        if self.fold is None:
            self.fold = nn.Fold(output_size=(H, W),
                                kernel_size=self.patch_size,
                                stride=self.patch_size)
        x_reconstructed = self.fold(patches)
        return x_reconstructed


class TransUNet(nn.Module):
    def __init__(self, in_ch=1, out_ch=1, base_ch=64):
        super().__init__()
        # ================= Encoder =================
        self.enc1 = ConvBlock(in_ch, base_ch)          # 1 -> 64
        self.enc2 = ConvBlock(base_ch, base_ch*2)     # 64 -> 128
        self.enc3 = ConvBlock(base_ch*2, base_ch*4)   # 128 -> 256
        self.pool = nn.MaxPool2d(2)

        # ================= Bottleneck =================
        self.bottleneck = TransformerBottleneck(base_ch*4, patch_size=16, num_layers=2, num_heads=4)

        # ================= Decoder =================
        # dec3 input: bottleneck + skip e3 = 256 + 256 = 512 channels
        self.dec3 = ConvBlock(base_ch*4 + base_ch*4, base_ch*2)  # 512 -> 128
        # dec2 input: dec3 output + skip e2 = 128 + 128 = 256 channels
        self.dec2 = ConvBlock(base_ch*2 + base_ch*2, base_ch)    # 256 -> 64

        # final conv
        self.out_conv = nn.Conv2d(base_ch, out_ch, kernel_size=1)  # 64 -> 1

    def forward(self, x):
        # ---------------- Encoder ----------------
        e1 = self.enc1(x)                  # [B,64,H,W]
        e2 = self.enc2(self.pool(e1))      # [B,128,H/2,W/2]
        e3 = self.enc3(self.pool(e2))      # [B,256,H/4,W/4]

        # ---------------- Bottleneck ----------------
        b = self.bottleneck(self.pool(e3)) # [B,256,H/8,W/8]

        # ---------------- Decoder ----------------
        d3 = F.interpolate(b, size=e3.shape[2:], mode='bilinear', align_corners=False)
        d3 = self.dec3(torch.cat([d3, e3], dim=1))  # [B,128,H/4,W/4]

        d2 = F.interpolate(d3, size=e2.shape[2:], mode='bilinear', align_corners=False)
        d2 = self.dec2(torch.cat([d2, e2], dim=1))  # [B,64,H/2,W/2]

        d1 = F.interpolate(d2, size=e1.shape[2:], mode='bilinear', align_corners=False)
        out = self.out_conv(d1)                        # [B,1,H,W]

        return out




# ==========================
# TRAINING
# ==========================
model = TransUNet().to(DEVICE)


from pytorch_msssim import ssim  # pip install pytorch-msssim

alpha = 0.8
beta  = 0.2

criterion = nn.HuberLoss(delta=1.0)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

best_val_loss = float("inf")

for epoch in range(EPOCHS):
    model.train()
    train_loss = 0.0
    for x, y in train_loader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        optimizer.zero_grad()
        preds = model(x)

        # Huber Loss
        huber_loss = criterion(preds, y)

        # SSIM Loss (1 - SSIM)
        ssim_loss = 1 - ssim(preds, y, data_range=1.0, size_average=True)  # assumes preds/y in [0,1]

        # Combined Loss
        loss = alpha * huber_loss + beta * ssim_loss
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * x.size(0)
    train_loss /= len(train_loader.dataset)

    # ---- Validation ----
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            preds = model(x)

            huber_loss = criterion(preds, y)
            ssim_loss = 1 - ssim(preds, y, data_range=1.0, size_average=True)
            loss = alpha * huber_loss + beta * ssim_loss

            val_loss += loss.item() * x.size(0)
    val_loss /= len(val_loader.dataset)

    print(f"Epoch [{epoch+1}/{EPOCHS}] Train Loss: {train_loss:.6f}, Val Loss: {val_loss:.6f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), MODEL_PATH)
        print(f"  ✅ Saved Best Model at Epoch {epoch+1}")


# ==========================
# INFERENCE
# ==========================
print("\nRunning inference on test set...")
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

for i, (x, y) in enumerate(test_loader):
    x = x.to(DEVICE)
    with torch.no_grad():
        pred = model(x).cpu().squeeze(0).squeeze(0).numpy()
    pred_img = (pred * 255.0).clip(0, 255).astype(np.uint8)
    Image.fromarray(pred_img).save(os.path.join(RESULTS_DIR, f"pred_{i+1}.png"))

print(f"Predictions saved in {RESULTS_DIR}")

Running on: cuda
Train: 5810, Val: 1245, Test: 1245
Epoch [1/75] Train Loss: 0.009544, Val Loss: 0.000565
  ✅ Saved Best Model at Epoch 1
Epoch [2/75] Train Loss: 0.000520, Val Loss: 0.000546
  ✅ Saved Best Model at Epoch 2
Epoch [3/75] Train Loss: 0.000439, Val Loss: 0.000497
  ✅ Saved Best Model at Epoch 3
Epoch [4/75] Train Loss: 0.000388, Val Loss: 0.000361
  ✅ Saved Best Model at Epoch 4
Epoch [5/75] Train Loss: 0.000362, Val Loss: 0.000345
  ✅ Saved Best Model at Epoch 5
Epoch [6/75] Train Loss: 0.000344, Val Loss: 0.000358
Epoch [7/75] Train Loss: 0.000331, Val Loss: 0.000314
  ✅ Saved Best Model at Epoch 7
Epoch [8/75] Train Loss: 0.000325, Val Loss: 0.000305
  ✅ Saved Best Model at Epoch 8
Epoch [9/75] Train Loss: 0.000315, Val Loss: 0.000301
  ✅ Saved Best Model at Epoch 9
Epoch [10/75] Train Loss: 0.000306, Val Loss: 0.000298
  ✅ Saved Best Model at Epoch 10
Epoch [11/75] Train Loss: 0.000304, Val Loss: 0.000294
  ✅ Saved Best Model at Epoch 11
Epoch [12/75] Train Loss: 0.00

In [2]:
import os
import cv2
import numpy as np
from skimage.metrics import structural_similarity as ssim
from math import log10
import shutil

# ========================== PATHS ==========================
RESULTS_DIR = r"C:\Preet\clean_paired_bscans\predictions_transunet_huber+ssim"  # predictions from TransUNet
GT_TEST_DIR = r"C:\Preet\clean_paired_bscans\ground_truth_test_transunet_huber+ssim"

# Copy ground truth test images to a dedicated folder
os.makedirs(GT_TEST_DIR, exist_ok=True)
for f in test_y:  # test_y comes from your train/val/test split
    shutil.copy(f, GT_TEST_DIR)

# ========================== FUNCTIONS ==========================
def calculate_psnr(img1, img2):
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * log10(255.0 / np.sqrt(mse))

def calculate_ssim(img1, img2):
    return ssim(img1, img2, data_range=255)

# ========================== MAIN EVALUATION ==========================
psnr_values, ssim_values = [], []

pred_files = sorted([f for f in os.listdir(RESULTS_DIR) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
gt_files   = sorted([f for f in os.listdir(GT_TEST_DIR) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])

num_pairs = min(len(pred_files), len(gt_files))
if num_pairs == 0:
    print("[Error] No matching image files found in both directories.")
else:
    if len(pred_files) != len(gt_files):
        print(f"[Warning] Different number of images: Predictions={len(pred_files)}, Ground Truth={len(gt_files)}")
        print(f"Evaluating only first {num_pairs} matched pairs.")

    for i in range(num_pairs):
        pred_path = os.path.join(RESULTS_DIR, pred_files[i])
        gt_path   = os.path.join(GT_TEST_DIR, gt_files[i])

        pred_img = cv2.imread(pred_path, cv2.IMREAD_GRAYSCALE)
        gt_img   = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)

        if pred_img is None or gt_img is None:
            print(f"[Error] Could not load: {pred_files[i]} or {gt_files[i]}")
            continue

        if pred_img.shape != gt_img.shape:
            pred_img = cv2.resize(pred_img, (gt_img.shape[1], gt_img.shape[0]))

        psnr_values.append(calculate_psnr(pred_img, gt_img))
        ssim_values.append(calculate_ssim(pred_img, gt_img))

    if psnr_values and ssim_values:
        print(f"\n---- Test Set Evaluation ----")
        print(f"SSIM: avg={np.mean(ssim_values):.4f}, min={np.min(ssim_values):.4f}, max={np.max(ssim_values):.4f}")
        print(f"PSNR: avg={np.mean(psnr_values):.2f} dB, min={np.min(psnr_values):.2f} dB, max={np.max(psnr_values):.2f} dB")


---- Test Set Evaluation ----
SSIM: avg=0.9044, min=0.7484, max=0.9979
PSNR: avg=35.99 dB, min=29.61 dB, max=52.27 dB


In [3]:
import os
import re
import numpy as np
from PIL import Image
import cv2
from math import log10
from skimage.metrics import structural_similarity as ssim

import torch
import torch.nn as nn
import torch.nn.functional as F

# =========================
# CONFIG
# =========================
FREQ_DIRS = [
    r"C:\Preet\clean_validation dataset\png_images_650M_1083M",
    r"C:\Preet\clean_validation dataset\png_images_700M_1167M",
    r"C:\Preet\clean_validation dataset\png_images_800M_1333M",
    r"C:\Preet\clean_validation dataset\png_images_850M_1416M",
    r"C:\Preet\clean_validation dataset\png_images_900M_1500M",
]

MODEL_PATH = r"C:\Preet\clean_paired_bscans\best_transunet_huber+ssim.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMAGE_SIZE = (256, 256)

# =========================
# MODEL BLOCKS
# =========================
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True)
        )
    def forward(self, x): return self.conv(x)

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads=4, ff_dim=512, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.drop = nn.Dropout(dropout)
    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + self.drop(attn_out))
        ff_out = self.ff(x)
        x = self.norm2(x + self.drop(ff_out))
        return x

class TransformerBottleneck(nn.Module):
    def __init__(self, in_ch, patch_size=16, num_layers=2, num_heads=4):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = in_ch
        self.flatten = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        self.project = nn.Linear(patch_size*patch_size*in_ch, in_ch)
        self.transformer = nn.Sequential(*[TransformerBlock(embed_dim=in_ch, num_heads=num_heads) for _ in range(num_layers)])
        self.reconstruct = nn.Linear(in_ch, patch_size*patch_size*in_ch)
        self.fold = None
    def forward(self, x):
        B,C,H,W = x.shape
        patches = self.flatten(x).transpose(1,2)
        patches = self.project(patches)
        patches = self.transformer(patches)
        patches = self.reconstruct(patches).transpose(1,2)
        if self.fold is None:
            self.fold = nn.Fold(output_size=(H,W), kernel_size=self.patch_size, stride=self.patch_size)
        x_reconstructed = self.fold(patches)
        return x_reconstructed

class TransUNet(nn.Module):
    def __init__(self, in_ch=1, out_ch=1, base_ch=64):
        super().__init__()
        self.enc1 = ConvBlock(in_ch, base_ch)
        self.enc2 = ConvBlock(base_ch, base_ch*2)
        self.enc3 = ConvBlock(base_ch*2, base_ch*4)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = TransformerBottleneck(base_ch*4, patch_size=16, num_layers=2, num_heads=4)
        self.dec3 = ConvBlock(base_ch*4 + base_ch*4, base_ch*2)
        self.dec2 = ConvBlock(base_ch*2 + base_ch*2, base_ch)
        self.out_conv = nn.Conv2d(base_ch, out_ch, kernel_size=1)
    def forward(self, x):
        e1 = self.enc1(x); e2 = self.enc2(self.pool(e1)); e3 = self.enc3(self.pool(e2))
        b = self.bottleneck(self.pool(e3))
        d3 = F.interpolate(b, size=e3.shape[2:], mode='bilinear', align_corners=False)
        d3 = self.dec3(torch.cat([d3, e3], dim=1))
        d2 = F.interpolate(d3, size=e2.shape[2:], mode='bilinear', align_corners=False)
        d2 = self.dec2(torch.cat([d2, e2], dim=1))
        d1 = F.interpolate(d2, size=e1.shape[2:], mode='bilinear', align_corners=False)
        return self.out_conv(d1)

# =========================
# HELPERS
# =========================
def calculate_psnr(img1, img2):
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0: return float('inf')
    return 20 * log10(255.0 / np.sqrt(mse))

def calculate_ssim(img1, img2):
    return ssim(img1, img2, data_range=255)

def numerical_sort(files):
    return sorted(files, key=lambda f: int(re.search(r'\d+', f).group()))

# =========================
# LOAD MODEL
# =========================
model = TransUNet().to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

# =========================
# INFERENCE LOOP
# =========================
for freq_dir in FREQ_DIRS:
    print(f"\n---- Evaluating {os.path.basename(freq_dir)} ----")
    pred_dir = os.path.join(freq_dir, "predictions_transunet_huber+ssim")
    os.makedirs(pred_dir, exist_ok=True)

    low_files = numerical_sort([f for f in os.listdir(freq_dir) if f.endswith("_l.png")])
    high_files = numerical_sort([f for f in os.listdir(freq_dir) if f.endswith("_h.png")])

    psnr_values, ssim_values = [], []

    for i in range(len(low_files)):
        low_path = os.path.join(freq_dir, low_files[i])
        high_path = os.path.join(freq_dir, high_files[i])

        # Load LR
        lr_img = np.array(Image.open(low_path).resize(IMAGE_SIZE), dtype=np.float32)/255.0
        lr_tensor = torch.tensor(lr_img).unsqueeze(0).unsqueeze(0).to(DEVICE)

        # Predict
        with torch.no_grad():
            pred = model(lr_tensor).cpu().numpy()

        # remove batch/channel dims and scale to 0-255
        pred_img = np.squeeze(pred)
        pred_img = (pred_img*255.0).clip(0,255).astype(np.uint8)
        pred_img = cv2.normalize(pred_img, None, 0, 255, cv2.NORM_MINMAX)

        # save prediction
        pred_path = os.path.join(pred_dir, f"pred_{i+1}.png")
        Image.fromarray(pred_img).save(pred_path)

        # Load GT
        gt_img = cv2.imread(high_path, cv2.IMREAD_GRAYSCALE)
        if pred_img.shape != gt_img.shape:
            pred_img = cv2.resize(pred_img, (gt_img.shape[1], gt_img.shape[0]))

        # Metrics
        psnr_values.append(calculate_psnr(pred_img, gt_img))
        ssim_values.append(calculate_ssim(pred_img, gt_img))

    if psnr_values and ssim_values:
        print(f"SSIM: avg={np.mean(ssim_values):.4f}, min={np.min(ssim_values):.4f}, max={np.max(ssim_values):.4f}")
        print(f"PSNR: avg={np.mean(psnr_values):.2f} dB, min={np.min(psnr_values):.2f} dB, max={np.max(psnr_values):.2f} dB")
    else:
        print("[Error] No valid pairs processed.")


---- Evaluating png_images_650M_1083M ----
SSIM: avg=0.9622, min=0.9309, max=0.9731
PSNR: avg=35.91 dB, min=30.69 dB, max=37.21 dB

---- Evaluating png_images_700M_1167M ----
SSIM: avg=0.9871, min=0.8320, max=0.9909
PSNR: avg=37.91 dB, min=32.60 dB, max=38.81 dB

---- Evaluating png_images_800M_1333M ----
SSIM: avg=0.9858, min=0.9795, max=0.9887
PSNR: avg=36.98 dB, min=34.60 dB, max=38.07 dB

---- Evaluating png_images_850M_1416M ----
SSIM: avg=0.9570, min=0.9341, max=0.9656
PSNR: avg=34.74 dB, min=31.40 dB, max=35.94 dB

---- Evaluating png_images_900M_1500M ----
SSIM: avg=0.9309, min=0.8809, max=0.9470
PSNR: avg=34.48 dB, min=31.15 dB, max=35.51 dB


In [4]:
import os
import re
import numpy as np
from PIL import Image
import cv2
from math import log10
from skimage.metrics import structural_similarity as ssim

import torch
import torch.nn as nn
import torch.nn.functional as F

# =========================
# CONFIG
# =========================
DATASET_DIR = r"C:\Preet\clean_validation dataset\png_images_650M_1083M_1800M"
MODEL_PATH = r"C:\Preet\clean_paired_bscans\best_transunet_huber+ssim.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMAGE_SIZE = (256, 256)

# =========================
# MODEL BLOCKS
# =========================
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True)
        )
    def forward(self, x): return self.conv(x)

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads=4, ff_dim=512, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.drop = nn.Dropout(dropout)
    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + self.drop(attn_out))
        ff_out = self.ff(x)
        x = self.norm2(x + self.drop(ff_out))
        return x

class TransformerBottleneck(nn.Module):
    def __init__(self, in_ch, patch_size=16, num_layers=2, num_heads=4):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = in_ch
        self.flatten = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        self.project = nn.Linear(patch_size*patch_size*in_ch, in_ch)
        self.transformer = nn.Sequential(*[TransformerBlock(embed_dim=in_ch, num_heads=num_heads) for _ in range(num_layers)])
        self.reconstruct = nn.Linear(in_ch, patch_size*patch_size*in_ch)
        self.fold = None
    def forward(self, x):
        B,C,H,W = x.shape
        patches = self.flatten(x).transpose(1,2)
        patches = self.project(patches)
        patches = self.transformer(patches)
        patches = self.reconstruct(patches).transpose(1,2)
        if self.fold is None:
            self.fold = nn.Fold(output_size=(H,W), kernel_size=self.patch_size, stride=self.patch_size)
        x_reconstructed = self.fold(patches)
        return x_reconstructed

class TransUNet(nn.Module):
    def __init__(self, in_ch=1, out_ch=1, base_ch=64):
        super().__init__()
        self.enc1 = ConvBlock(in_ch, base_ch)
        self.enc2 = ConvBlock(base_ch, base_ch*2)
        self.enc3 = ConvBlock(base_ch*2, base_ch*4)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = TransformerBottleneck(base_ch*4, patch_size=16, num_layers=2, num_heads=4)
        self.dec3 = ConvBlock(base_ch*4 + base_ch*4, base_ch*2)
        self.dec2 = ConvBlock(base_ch*2 + base_ch*2, base_ch)
        self.out_conv = nn.Conv2d(base_ch, out_ch, kernel_size=1)
    def forward(self, x):
        e1 = self.enc1(x); e2 = self.enc2(self.pool(e1)); e3 = self.enc3(self.pool(e2))
        b = self.bottleneck(self.pool(e3))
        d3 = F.interpolate(b, size=e3.shape[2:], mode='bilinear', align_corners=False)
        d3 = self.dec3(torch.cat([d3, e3], dim=1))
        d2 = F.interpolate(d3, size=e2.shape[2:], mode='bilinear', align_corners=False)
        d2 = self.dec2(torch.cat([d2, e2], dim=1))
        d1 = F.interpolate(d2, size=e1.shape[2:], mode='bilinear', align_corners=False)
        return self.out_conv(d1)

# =========================
# HELPERS
# =========================
def calculate_psnr(img1, img2):
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0: return float('inf')
    return 20 * log10(255.0 / np.sqrt(mse))

def calculate_ssim(img1, img2):
    return ssim(img1, img2, data_range=255)

def numerical_sort(files):
    return sorted(files, key=lambda f: int(re.search(r'\d+', f).group()))

def run_model(img, model):
    tensor = torch.tensor(img, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        pred = model(tensor).cpu().numpy()
    pred = np.squeeze(pred)
    pred = (pred * 255.0).clip(0, 255).astype(np.uint8)
    return cv2.normalize(pred, None, 0, 255, cv2.NORM_MINMAX)

# =========================
# LOAD MODEL
# =========================
model = TransUNet().to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

# =========================
# INFERENCE 2-STAGE
# =========================
files_650 = numerical_sort([f for f in os.listdir(DATASET_DIR) if f.endswith("650_bscan.png")])
files_1083 = numerical_sort([f for f in os.listdir(DATASET_DIR) if f.endswith("1083_bscan.png")])
files_1800 = numerical_sort([f for f in os.listdir(DATASET_DIR) if f.endswith("1800_bscan.png")])

psnr_stage1, ssim_stage1 = [], []
psnr_stage2, ssim_stage2 = [], []

pred_dir1 = os.path.join(DATASET_DIR, "predictions_stage1_transunet_huber+ssim")
pred_dir2 = os.path.join(DATASET_DIR, "predictions_stage2_transunet_huber+ssim")
os.makedirs(pred_dir1, exist_ok=True)
os.makedirs(pred_dir2, exist_ok=True)

for i in range(len(files_650)):
    # ---- Stage 1: 650 -> pred -> compare with 1083 ----
    lr_img = np.array(Image.open(os.path.join(DATASET_DIR, files_650[i])).resize(IMAGE_SIZE), dtype=np.float32) / 255.0
    gt_1083 = cv2.imread(os.path.join(DATASET_DIR, files_1083[i]), cv2.IMREAD_GRAYSCALE)

    pred_1083 = run_model(lr_img, model)
    Image.fromarray(pred_1083).save(os.path.join(pred_dir1, f"pred1_{i+1}.png"))

    if pred_1083.shape != gt_1083.shape:
        pred_1083 = cv2.resize(pred_1083, (gt_1083.shape[1], gt_1083.shape[0]))

    psnr_stage1.append(calculate_psnr(pred_1083, gt_1083))
    ssim_stage1.append(calculate_ssim(pred_1083, gt_1083))

    # ---- Stage 2: pred_1083 -> pred -> compare with 1800 ----
    gt_1800 = cv2.imread(os.path.join(DATASET_DIR, files_1800[i]), cv2.IMREAD_GRAYSCALE)
    pred_1083_resized = cv2.resize(pred_1083, IMAGE_SIZE).astype(np.float32)/255.0
    pred_1800 = run_model(pred_1083_resized, model)
    Image.fromarray(pred_1800).save(os.path.join(pred_dir2, f"pred2_{i+1}.png"))

    if pred_1800.shape != gt_1800.shape:
        pred_1800 = cv2.resize(pred_1800, (gt_1800.shape[1], gt_1800.shape[0]))

    psnr_stage2.append(calculate_psnr(pred_1800, gt_1800))
    ssim_stage2.append(calculate_ssim(pred_1800, gt_1800))

# =========================
# RESULTS
# =========================
print("\n---- Stage 1: 650 → pred → compare with 1083 ----")
print(f"SSIM: avg={np.mean(ssim_stage1):.4f}, min={np.min(ssim_stage1):.4f}, max={np.max(ssim_stage1):.4f}")
print(f"PSNR: avg={np.mean(psnr_stage1):.2f} dB, min={np.min(psnr_stage1):.2f} dB, max={np.max(psnr_stage1):.2f} dB")

print("\n---- Stage 2: pred_1083 → pred → compare with 1800 ----")
print(f"SSIM: avg={np.mean(ssim_stage2):.4f}, min={np.min(ssim_stage2):.4f}, max={np.max(ssim_stage2):.4f}")
print(f"PSNR: avg={np.mean(psnr_stage2):.2f} dB, min={np.min(psnr_stage2):.2f} dB, max={np.max(psnr_stage2):.2f} dB")



---- Stage 1: 650 → pred → compare with 1083 ----
SSIM: avg=0.9609, min=0.9387, max=0.9726
PSNR: avg=35.82 dB, min=33.45 dB, max=37.16 dB

---- Stage 2: pred_1083 → pred → compare with 1800 ----
SSIM: avg=0.9243, min=0.8989, max=0.9428
PSNR: avg=31.00 dB, min=29.74 dB, max=32.30 dB


In [5]:
import os
import re
import numpy as np
from PIL import Image
import cv2
from math import log10
from skimage.metrics import structural_similarity as ssim

import torch
import torch.nn as nn
import torch.nn.functional as F

# =========================
# CONFIG
# =========================
DATASET_DIR = r"C:\Preet\clean_validation dataset\png_images_900M_1500M_2500M"
MODEL_PATH = r"C:\Preet\clean_paired_bscans\best_transunet_huber+ssim.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMAGE_SIZE = (256, 256)

# =========================
# MODEL BLOCKS
# =========================
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True)
        )
    def forward(self, x): return self.conv(x)

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads=4, ff_dim=512, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.drop = nn.Dropout(dropout)
    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + self.drop(attn_out))
        ff_out = self.ff(x)
        x = self.norm2(x + self.drop(ff_out))
        return x

class TransformerBottleneck(nn.Module):
    def __init__(self, in_ch, patch_size=16, num_layers=2, num_heads=4):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = in_ch
        self.flatten = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        self.project = nn.Linear(patch_size*patch_size*in_ch, in_ch)
        self.transformer = nn.Sequential(*[TransformerBlock(embed_dim=in_ch, num_heads=num_heads) for _ in range(num_layers)])
        self.reconstruct = nn.Linear(in_ch, patch_size*patch_size*in_ch)
        self.fold = None
    def forward(self, x):
        B,C,H,W = x.shape
        patches = self.flatten(x).transpose(1,2)
        patches = self.project(patches)
        patches = self.transformer(patches)
        patches = self.reconstruct(patches).transpose(1,2)
        if self.fold is None:
            self.fold = nn.Fold(output_size=(H,W), kernel_size=self.patch_size, stride=self.patch_size)
        x_reconstructed = self.fold(patches)
        return x_reconstructed

class TransUNet(nn.Module):
    def __init__(self, in_ch=1, out_ch=1, base_ch=64):
        super().__init__()
        self.enc1 = ConvBlock(in_ch, base_ch)
        self.enc2 = ConvBlock(base_ch, base_ch*2)
        self.enc3 = ConvBlock(base_ch*2, base_ch*4)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = TransformerBottleneck(base_ch*4, patch_size=16, num_layers=2, num_heads=4)
        self.dec3 = ConvBlock(base_ch*4 + base_ch*4, base_ch*2)
        self.dec2 = ConvBlock(base_ch*2 + base_ch*2, base_ch)
        self.out_conv = nn.Conv2d(base_ch, out_ch, kernel_size=1)
    def forward(self, x):
        e1 = self.enc1(x); e2 = self.enc2(self.pool(e1)); e3 = self.enc3(self.pool(e2))
        b = self.bottleneck(self.pool(e3))
        d3 = F.interpolate(b, size=e3.shape[2:], mode='bilinear', align_corners=False)
        d3 = self.dec3(torch.cat([d3, e3], dim=1))
        d2 = F.interpolate(d3, size=e2.shape[2:], mode='bilinear', align_corners=False)
        d2 = self.dec2(torch.cat([d2, e2], dim=1))
        d1 = F.interpolate(d2, size=e1.shape[2:], mode='bilinear', align_corners=False)
        return self.out_conv(d1)

# =========================
# HELPERS
# =========================
def calculate_psnr(img1, img2):
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0: return float('inf')
    return 20 * log10(255.0 / np.sqrt(mse))

def calculate_ssim(img1, img2):
    return ssim(img1, img2, data_range=255)

def numerical_sort(files):
    return sorted(files, key=lambda f: int(re.search(r'\d+', f).group()))

def run_model(img, model):
    tensor = torch.tensor(img, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        pred = model(tensor).cpu().numpy()
    pred = np.squeeze(pred)
    pred = (pred * 255.0).clip(0, 255).astype(np.uint8)
    return cv2.normalize(pred, None, 0, 255, cv2.NORM_MINMAX)

# =========================
# LOAD MODEL
# =========================
model = TransUNet().to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

# =========================
# INFERENCE 2-STAGE
# =========================
files_650 = numerical_sort([f for f in os.listdir(DATASET_DIR) if f.endswith("900_bscan.png")])
files_1083 = numerical_sort([f for f in os.listdir(DATASET_DIR) if f.endswith("1500_bscan.png")])
files_1800 = numerical_sort([f for f in os.listdir(DATASET_DIR) if f.endswith("2500_bscan.png")])

psnr_stage1, ssim_stage1 = [], []
psnr_stage2, ssim_stage2 = [], []

pred_dir1 = os.path.join(DATASET_DIR, "predictions_stage1_transunet_huber+ssim")
pred_dir2 = os.path.join(DATASET_DIR, "predictions_stage2_transunet_huber+ssim")
os.makedirs(pred_dir1, exist_ok=True)
os.makedirs(pred_dir2, exist_ok=True)

for i in range(len(files_650)):
    # ---- Stage 1: 650 -> pred -> compare with 1083 ----
    lr_img = np.array(Image.open(os.path.join(DATASET_DIR, files_650[i])).resize(IMAGE_SIZE), dtype=np.float32) / 255.0
    gt_1083 = cv2.imread(os.path.join(DATASET_DIR, files_1083[i]), cv2.IMREAD_GRAYSCALE)

    pred_1083 = run_model(lr_img, model)
    Image.fromarray(pred_1083).save(os.path.join(pred_dir1, f"pred1_{i+1}.png"))

    if pred_1083.shape != gt_1083.shape:
        pred_1083 = cv2.resize(pred_1083, (gt_1083.shape[1], gt_1083.shape[0]))

    psnr_stage1.append(calculate_psnr(pred_1083, gt_1083))
    ssim_stage1.append(calculate_ssim(pred_1083, gt_1083))

    # ---- Stage 2: pred_1083 -> pred -> compare with 1800 ----
    gt_1800 = cv2.imread(os.path.join(DATASET_DIR, files_1800[i]), cv2.IMREAD_GRAYSCALE)
    pred_1083_resized = cv2.resize(pred_1083, IMAGE_SIZE).astype(np.float32)/255.0
    pred_1800 = run_model(pred_1083_resized, model)
    Image.fromarray(pred_1800).save(os.path.join(pred_dir2, f"pred2_{i+1}.png"))

    if pred_1800.shape != gt_1800.shape:
        pred_1800 = cv2.resize(pred_1800, (gt_1800.shape[1], gt_1800.shape[0]))

    psnr_stage2.append(calculate_psnr(pred_1800, gt_1800))
    ssim_stage2.append(calculate_ssim(pred_1800, gt_1800))

# =========================
# RESULTS
# =========================
print("\n---- Stage 1: 900 → pred → compare with 1500 ----")
print(f"SSIM: avg={np.mean(ssim_stage1):.4f}, min={np.min(ssim_stage1):.4f}, max={np.max(ssim_stage1):.4f}")
print(f"PSNR: avg={np.mean(psnr_stage1):.2f} dB, min={np.min(psnr_stage1):.2f} dB, max={np.max(psnr_stage1):.2f} dB")

print("\n---- Stage 2: pred_1500 → pred → compare with 2500 ----")
print(f"SSIM: avg={np.mean(ssim_stage2):.4f}, min={np.min(ssim_stage2):.4f}, max={np.max(ssim_stage2):.4f}")
print(f"PSNR: avg={np.mean(psnr_stage2):.2f} dB, min={np.min(psnr_stage2):.2f} dB, max={np.max(psnr_stage2):.2f} dB")



---- Stage 1: 900 → pred → compare with 1500 ----
SSIM: avg=0.7474, min=0.5995, max=0.8695
PSNR: avg=29.93 dB, min=28.32 dB, max=30.53 dB

---- Stage 2: pred_1500 → pred → compare with 2500 ----
SSIM: avg=0.7693, min=0.6452, max=0.8678
PSNR: avg=28.72 dB, min=27.50 dB, max=29.59 dB


for testing 400Mhz dataset on trained 750Mhz model

In [None]:
import os
import re
import numpy as np
from PIL import Image
import cv2
from math import log10
from skimage.metrics import structural_similarity as ssim

import torch
import torch.nn as nn
import torch.nn.functional as F



MODEL_PATH = r"C:\Preet\clean_paired_bscans\best_transunet_huber+ssim.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMAGE_SIZE = (256, 256)

# =========================
# MODEL BLOCKS
# =========================
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True)
        )
    def forward(self, x): return self.conv(x)

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads=4, ff_dim=512, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.drop = nn.Dropout(dropout)
    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + self.drop(attn_out))
        ff_out = self.ff(x)
        x = self.norm2(x + self.drop(ff_out))
        return x

class TransformerBottleneck(nn.Module):
    def __init__(self, in_ch, patch_size=16, num_layers=2, num_heads=4):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = in_ch
        self.flatten = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        self.project = nn.Linear(patch_size*patch_size*in_ch, in_ch)
        self.transformer = nn.Sequential(*[TransformerBlock(embed_dim=in_ch, num_heads=num_heads) for _ in range(num_layers)])
        self.reconstruct = nn.Linear(in_ch, patch_size*patch_size*in_ch)
        self.fold = None
    def forward(self, x):
        B,C,H,W = x.shape
        patches = self.flatten(x).transpose(1,2)
        patches = self.project(patches)
        patches = self.transformer(patches)
        patches = self.reconstruct(patches).transpose(1,2)
        if self.fold is None:
            self.fold = nn.Fold(output_size=(H,W), kernel_size=self.patch_size, stride=self.patch_size)
        x_reconstructed = self.fold(patches)
        return x_reconstructed

class TransUNet(nn.Module):
    def __init__(self, in_ch=1, out_ch=1, base_ch=64):
        super().__init__()
        self.enc1 = ConvBlock(in_ch, base_ch)
        self.enc2 = ConvBlock(base_ch, base_ch*2)
        self.enc3 = ConvBlock(base_ch*2, base_ch*4)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = TransformerBottleneck(base_ch*4, patch_size=16, num_layers=2, num_heads=4)
        self.dec3 = ConvBlock(base_ch*4 + base_ch*4, base_ch*2)
        self.dec2 = ConvBlock(base_ch*2 + base_ch*2, base_ch)
        self.out_conv = nn.Conv2d(base_ch, out_ch, kernel_size=1)
    def forward(self, x):
        e1 = self.enc1(x); e2 = self.enc2(self.pool(e1)); e3 = self.enc3(self.pool(e2))
        b = self.bottleneck(self.pool(e3))
        d3 = F.interpolate(b, size=e3.shape[2:], mode='bilinear', align_corners=False)
        d3 = self.dec3(torch.cat([d3, e3], dim=1))
        d2 = F.interpolate(d3, size=e2.shape[2:], mode='bilinear', align_corners=False)
        d2 = self.dec2(torch.cat([d2, e2], dim=1))
        d1 = F.interpolate(d2, size=e1.shape[2:], mode='bilinear', align_corners=False)
        return self.out_conv(d1)

# =========================
# HELPERS
# =========================
def calculate_psnr(img1, img2):
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0: return float('inf')
    return 20 * log10(255.0 / np.sqrt(mse))

def calculate_ssim(img1, img2):
    return ssim(img1, img2, data_range=255)

def numerical_sort(files):
    return sorted(files, key=lambda f: int(re.search(r'\d+', f).group()))

# =========================
# LOAD MODEL
# =========================
model = TransUNet().to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()



TEST_DIR = r"C:\Preet\400_670_Dataset"

print(f"\n---- Evaluating {TEST_DIR} ----")
pred_dir = os.path.join(TEST_DIR, "predictions_unet+RB+AG+hybrid")
os.makedirs(pred_dir, exist_ok=True)

low_files = numerical_sort([f for f in os.listdir(TEST_DIR) if f.endswith("_l.png")])
high_files = numerical_sort([f for f in os.listdir(TEST_DIR) if f.endswith("_h.png")])

psnr_values, ssim_values = [], []

for i in range(len(low_files)):
    low_path = os.path.join(TEST_DIR, low_files[i])
    high_path = os.path.join(TEST_DIR, high_files[i])

    # Load LR
    lr_img = np.array(Image.open(low_path).resize(IMAGE_SIZE), dtype=np.float32) / 255.0
    lr_tensor = torch.tensor(lr_img).unsqueeze(0).unsqueeze(0).to(DEVICE)

    # Predict
    with torch.no_grad():
        pred = model(lr_tensor).cpu().squeeze().numpy()

    # Scale to 0–255
    pred_img = (np.clip(pred, 0.0, 1.0) * 255.0).astype(np.uint8)
    pred_path = os.path.join(pred_dir, f"pred_{i+1}.png")
    Image.fromarray(pred_img).save(pred_path)

    # Load GT
    gt_img = cv2.imread(high_path, cv2.IMREAD_GRAYSCALE)
    if pred_img.shape != gt_img.shape:
        pred_img = cv2.resize(pred_img, (gt_img.shape[1], gt_img.shape[0]))

    # Metrics
    psnr_values.append(calculate_psnr(pred_img, gt_img))
    ssim_values.append(calculate_ssim(pred_img, gt_img))

if psnr_values and ssim_values:
    print(f"SSIM: avg={np.mean(ssim_values):.4f}, min={np.min(ssim_values):.4f}, max={np.max(ssim_values):.4f}")
    print(f"PSNR: avg={np.mean(psnr_values):.2f} dB, min={np.min(psnr_values):.2f} dB, max={np.max(psnr_values):.2f} dB")
else:
    print("[Error] No valid *_l.png and *_h.png pairs found.")


for testing 400Mhz on trained 750Mhz

In [2]:
import os
import re
import numpy as np
from PIL import Image
import cv2
from math import log10
from skimage.metrics import structural_similarity as ssim

import torch
import torch.nn as nn
import torch.nn.functional as F

TEST_DIR = r"C:\Preet\400_670_Dataset"
MODEL_PATH = r"C:\Preet\clean_paired_bscans\best_transunet_huber+ssim.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMAGE_SIZE = (256, 256)

# =========================
# MODEL BLOCKS
# =========================
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True)
        )
    def forward(self, x): return self.conv(x)

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads=4, ff_dim=512, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim)
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.drop = nn.Dropout(dropout)
    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + self.drop(attn_out))
        ff_out = self.ff(x)
        x = self.norm2(x + self.drop(ff_out))
        return x

class TransformerBottleneck(nn.Module):
    def __init__(self, in_ch, patch_size=16, num_layers=2, num_heads=4):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = in_ch
        self.flatten = nn.Unfold(kernel_size=patch_size, stride=patch_size)
        self.project = nn.Linear(patch_size*patch_size*in_ch, in_ch)
        self.transformer = nn.Sequential(*[TransformerBlock(embed_dim=in_ch, num_heads=num_heads) for _ in range(num_layers)])
        self.reconstruct = nn.Linear(in_ch, patch_size*patch_size*in_ch)
        self.fold = None
    def forward(self, x):
        B,C,H,W = x.shape
        patches = self.flatten(x).transpose(1,2)
        patches = self.project(patches)
        patches = self.transformer(patches)
        patches = self.reconstruct(patches).transpose(1,2)
        if self.fold is None:
            self.fold = nn.Fold(output_size=(H,W), kernel_size=self.patch_size, stride=self.patch_size)
        x_reconstructed = self.fold(patches)
        return x_reconstructed

class TransUNet(nn.Module):
    def __init__(self, in_ch=1, out_ch=1, base_ch=64):
        super().__init__()
        self.enc1 = ConvBlock(in_ch, base_ch)
        self.enc2 = ConvBlock(base_ch, base_ch*2)
        self.enc3 = ConvBlock(base_ch*2, base_ch*4)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = TransformerBottleneck(base_ch*4, patch_size=16, num_layers=2, num_heads=4)
        self.dec3 = ConvBlock(base_ch*4 + base_ch*4, base_ch*2)
        self.dec2 = ConvBlock(base_ch*2 + base_ch*2, base_ch)
        self.out_conv = nn.Conv2d(base_ch, out_ch, kernel_size=1)
    def forward(self, x):
        e1 = self.enc1(x); e2 = self.enc2(self.pool(e1)); e3 = self.enc3(self.pool(e2))
        b = self.bottleneck(self.pool(e3))
        d3 = F.interpolate(b, size=e3.shape[2:], mode='bilinear', align_corners=False)
        d3 = self.dec3(torch.cat([d3, e3], dim=1))
        d2 = F.interpolate(d3, size=e2.shape[2:], mode='bilinear', align_corners=False)
        d2 = self.dec2(torch.cat([d2, e2], dim=1))
        d1 = F.interpolate(d2, size=e1.shape[2:], mode='bilinear', align_corners=False)
        return self.out_conv(d1)

# =========================
# HELPERS
# =========================
def calculate_psnr(img1, img2):
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0: return float('inf')
    return 20 * log10(255.0 / np.sqrt(mse))

def calculate_ssim(img1, img2):
    return ssim(img1, img2, data_range=255)

def numerical_sort(files):
    return sorted(files, key=lambda f: int(re.search(r'\d+', f).group()))

# =========================
# LOAD MODEL
# =========================
model = TransUNet().to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()




print(f"\n---- Evaluating {TEST_DIR} ----")
pred_dir = os.path.join(TEST_DIR, "predictions_transunet_huber")
os.makedirs(pred_dir, exist_ok=True)

low_files = numerical_sort([f for f in os.listdir(TEST_DIR) if f.endswith("_l.png")])
high_files = numerical_sort([f for f in os.listdir(TEST_DIR) if f.endswith("_h.png")])

psnr_values, ssim_values = [], []

for i in range(len(low_files)):
    low_path = os.path.join(TEST_DIR, low_files[i])
    high_path = os.path.join(TEST_DIR, high_files[i])

    # Load LR
    lr_img = np.array(Image.open(low_path).resize(IMAGE_SIZE), dtype=np.float32) / 255.0
    lr_tensor = torch.tensor(lr_img).unsqueeze(0).unsqueeze(0).to(DEVICE)

    # Predict
    with torch.no_grad():
        pred = model(lr_tensor).cpu().squeeze().numpy()

    # Scale to 0–255
    pred_img = (np.clip(pred, 0.0, 1.0) * 255.0).astype(np.uint8)
    pred_path = os.path.join(pred_dir, f"pred_{i+1}.png")
    Image.fromarray(pred_img).save(pred_path)

    # Load GT
    gt_img = cv2.imread(high_path, cv2.IMREAD_GRAYSCALE)
    if pred_img.shape != gt_img.shape:
        pred_img = cv2.resize(pred_img, (gt_img.shape[1], gt_img.shape[0]))

    # Metrics
    psnr_values.append(calculate_psnr(pred_img, gt_img))
    ssim_values.append(calculate_ssim(pred_img, gt_img))

if psnr_values and ssim_values:
    print(f"SSIM: avg={np.mean(ssim_values):.4f}, min={np.min(ssim_values):.4f}, max={np.max(ssim_values):.4f}")
    print(f"PSNR: avg={np.mean(psnr_values):.2f} dB, min={np.min(psnr_values):.2f} dB, max={np.max(psnr_values):.2f} dB")
else:
    print("[Error] No valid *_l.png and *_h.png pairs found.")



---- Evaluating C:\Preet\400_670_Dataset ----
SSIM: avg=0.7523, min=0.6270, max=0.8185
PSNR: avg=29.69 dB, min=29.07 dB, max=30.24 dB
