In [1]:
import torch
from os import path
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms

import h5py
import numpy as np
import wget
import matplotlib.pyplot as plt
from skimage.metrics import peak_signal_noise_ratio as psnr, structural_similarity as ssim
import csv
from torchvision.utils import save_image

import os
import glob
from PIL import Image
import numpy as np
import torch
import csv

import lpips



In [2]:
## Check that the data path is correct.
# The location can be overridden with the XRAY_DATA_PATH environment variable so the
# notebook is not tied to one machine; the default is the original location.
data_path = os.environ.get("XRAY_DATA_PATH", "/home/sulaimon/EXPERIMENT/23/images_001")
# glob() returns files in arbitrary, filesystem-dependent order. Sorting makes the image
# order -- and therefore which images a seeded random split assigns to each subset --
# reproducible across machines and runs.
image_files = sorted(glob.glob(os.path.join(data_path, '*.png')))
print(f'Found {len(image_files)} Images.')

Found 14999 Images.


In [None]:
#=====Load Data=====
def load_xray_image(data_path, size=(256, 256)):
    """Load one image file as a normalized, single-channel float32 array.

    Args:
        data_path: Path to a single image file (not the dataset directory).
        size: Target ``(width, height)`` passed to PIL's ``Image.resize``.

    Returns:
        np.ndarray of shape ``(1, height, width)``, dtype float32, values in [0, 1].
    """
    # The context manager releases the file handle deterministically, even if
    # decoding/conversion raises, instead of waiting for garbage collection.
    with Image.open(data_path) as img:
        gray = img.convert('L')  # Ensure 8-bit grayscale
    gray = gray.resize(size)
    img_array = np.asarray(gray, dtype=np.float32) / 255.0  # Normalize to [0,1]
    return img_array[np.newaxis, ...]  # Add channel dimension: (1, H, W)

# Decode every file and stack the per-image (1, 256, 256) arrays into one
# (N, 1, 256, 256) array.
InputImages = np.array(list(map(load_xray_image, image_files)))
print("Dataset Shape:", InputImages.shape)

Dataset Shape: (14999, 1, 256, 256)


In [None]:
#=====Add Noise=====
def add_gaussian_noise(images, mean=0.1, stddev=0.1, generator=None):
    """Corrupt images with additive Gaussian noise and clip the result to [0, 1].

    Args:
        images: Tensor (or array-like accepted by ``torch.as_tensor``) of clean
            images with values in [0, 1].
        mean: Mean of the additive noise. NOTE(review): a non-zero mean also
            brightens every image; zero-mean noise is the usual AWGN model. The
            default stays 0.1 so results remain comparable with earlier runs.
        stddev: Standard deviation of the additive noise.
        generator: Optional ``torch.Generator`` for reproducible noise; it must
            live on the same device as ``images``.

    Returns:
        Tensor with the same shape and device as ``images``, clamped to [0, 1].
    """
    images = torch.as_tensor(images)
    # Draw the noise directly on the images' device with a matching float dtype.
    # A float32 CPU noise tensor cannot be added to CUDA images, and it would
    # silently lower the precision of float64 inputs.
    noise_dtype = images.dtype if images.is_floating_point() else torch.get_default_dtype()
    noise = torch.randn(images.shape, generator=generator, dtype=noise_dtype, device=images.device)
    noisy_images = images + (noise * stddev + mean)
    return torch.clamp(noisy_images, 0., 1.)

# Seed before drawing the noise so the noisy dataset is reproducible across runs.
# A seed set later only affects random draws made after that point.
torch.manual_seed(42)
# as_tensor reuses the float32 numpy buffer instead of copying it (~4 GB for the
# full dataset), and is a no-op if this cell is re-run on an existing tensor.
InputImages = torch.as_tensor(InputImages, dtype=torch.float32)
noisy_images = add_gaussian_noise(InputImages)


In [None]:
#=====Data Loader=====
class X_rayDataset(Dataset):
    """Paired ``(noisy, clean)`` dataset backed by two index-aligned sequences.

    Args:
        noisy: Indexable collection of noisy inputs (in this notebook, a tensor
            of shape (N, 1, H, W)).
        clean: Indexable collection of clean targets, aligned index-for-index
            with ``noisy``.

    Raises:
        ValueError: If ``noisy`` and ``clean`` have different lengths.
    """

    def __init__(self, noisy, clean):
        # Fail fast: otherwise a length mismatch only shows up later, as an
        # IndexError inside a DataLoader or as silently dropped targets.
        if len(noisy) != len(clean):
            raise ValueError(
                f"noisy and clean must have the same length, got {len(noisy)} and {len(clean)}"
            )
        self.noisy = noisy
        self.clean = clean

    def __len__(self):
        """Number of (noisy, clean) pairs."""
        return len(self.noisy)

    def __getitem__(self, idx):
        """Return the ``(noisy, clean)`` pair at position ``idx``."""
        return self.noisy[idx], self.clean[idx]


## Splitting and DataLoader creation
# Pair each noisy image with its clean original.
dataset = X_rayDataset(noisy_images, InputImages)

# Seed torch and numpy right before the split, so the partition (and the
# DataLoader shuffling that follows) is reproducible.
torch.manual_seed(42)
np.random.seed(42)

# 50% train / 33% validation / remainder (~17%) test.
n_total = len(dataset)
train_size = int(0.5 * n_total)
val_size = int(0.33 * n_total)
test_size = n_total - (train_size + val_size)

train_set, val_set, test_set = random_split(dataset, [train_size, val_size, test_size])

BATCH_SIZE = 32
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)



## Verify the samples
def print_dataset_shapes(loader, name):
    """Print the tensor shapes of the first ``(noisy, clean)`` batch from ``loader``.

    Args:
        loader: Any iterable yielding ``(noisy, clean)`` pairs, e.g. a DataLoader.
        name: Label used at the start of the printed line.

    Returns:
        ``(noisy_shape, clean_shape)`` for the first batch. If the loader yields
        nothing, prints a notice and returns ``None`` instead of raising
        ``StopIteration``.
    """
    first_batch = next(iter(loader), None)
    if first_batch is None:
        print(f"{name} set: empty, no batches to inspect")
        return None
    noisy_sample, clean_sample = first_batch
    print(f"{name} set: Noisy shape: {noisy_sample.shape}, Clean shape: {clean_sample.shape}")
    return noisy_sample.shape, clean_sample.shape


# Show the first-batch shapes of each split, then the number of samples per split.
for split_loader, split_name in (
    (train_loader, "Train"),
    (val_loader, "Validation"),
    (test_loader, "Test"),
):
    print_dataset_shapes(split_loader, split_name)


print(f"Train set size: {len(train_set)}")
print(f"Val Set: {len(val_set)}")
print(f"Test Set: {len(test_set)}")

Train set: Noisy shape: torch.Size([32, 1, 256, 256]), Clean shape: torch.Size([32, 1, 256, 256])
Validation set: Noisy shape: torch.Size([32, 1, 256, 256]), Clean shape: torch.Size([32, 1, 256, 256])
Test set: Noisy shape: torch.Size([32, 1, 256, 256]), Clean shape: torch.Size([32, 1, 256, 256])
Train set size: 7499
Val Set: 4949
Test Set: 2551


In [None]:
#=====UNet++ Model Definition=====

# ConvBlock: 3x3 conv -> ReLU -> 3x3 conv -> BatchNorm -> ReLU, all with "same" padding.
class ConvBlock(nn.Module):
    """Two 3x3 same-padded convolutions, each followed by a ReLU.

    BatchNorm is applied only after the second convolution. The layers live in
    one ``nn.Sequential`` named ``block``, so parameter names (``block.0``,
    ``block.2``, ``block.3``) match previously saved checkpoints.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        """Apply the block; spatial size is preserved and channels become ``out_ch``."""
        return self.block(x)


class UNetplusplus(nn.Module):
    """UNet++ (nested U-Net) with 5 encoder depths and dense skip connections.

    Node ``Xi_j`` sits at depth ``i`` (resolution halves with each depth) and
    decoder column ``j``. Each decoder node concatenates every earlier node at its
    depth with the upsampled node ``X(i+1)_(j-1)`` from one depth below.

    Args:
        input_channels: Channels of the input image (1 for grayscale).
        output_channels: Channels of every 1x1 prediction head.
        base_ch: Channel width at depth 0; doubled at every deeper level.
    """

    def __init__(self, input_channels=1, output_channels=1, base_ch=64):
        super(UNetplusplus, self).__init__()

        Nc = [base_ch, base_ch*2, base_ch*4, base_ch*8, base_ch*16]

        # Encoder Xi0
        self.conv0_0 = ConvBlock(input_channels, Nc[0])  # X0_0
        self.conv1_0 = ConvBlock(Nc[0], Nc[1])  # X1_0
        self.conv2_0 = ConvBlock(Nc[1], Nc[2])  # X2_0
        self.conv3_0 = ConvBlock(Nc[2], Nc[3])  # X3_0
        self.conv4_0 = ConvBlock(Nc[3], Nc[4])  # X4_0

        self.pool = nn.MaxPool2d(2,2)

        # Decoder column 1 (Xi1): skip + upsampled deeper node
        self.conv0_1 = ConvBlock(Nc[0] + Nc[1], Nc[0])
        self.conv1_1 = ConvBlock(Nc[1] + Nc[2], Nc[1])
        self.conv2_1 = ConvBlock(Nc[2] + Nc[3], Nc[2])
        self.conv3_1 = ConvBlock(Nc[3] + Nc[4], Nc[3])

        # Decoder column 2 (Xi2): two same-depth nodes + upsampled deeper node
        self.conv0_2 = ConvBlock(Nc[0]*2 + Nc[1], Nc[0])
        self.conv1_2 = ConvBlock(Nc[1]*2 + Nc[2], Nc[1])
        self.conv2_2 = ConvBlock(Nc[2]*2 + Nc[3], Nc[2])

        # Decoder column 3 (Xi3)
        self.conv0_3 = ConvBlock(Nc[0]*3 + Nc[1], Nc[0])
        self.conv1_3 = ConvBlock(Nc[1]*3 + Nc[2], Nc[1])

        # Decoder column 4 (Xi4)
        self.conv0_4 = ConvBlock(Nc[0]*4 + Nc[1], Nc[0])

        # 1x1 prediction heads on the top-row nodes (deep supervision)
        self.final_X01 = nn.Conv2d(Nc[0], output_channels, kernel_size=1)
        self.final_X02 = nn.Conv2d(Nc[0], output_channels, kernel_size=1)
        self.final_X03 = nn.Conv2d(Nc[0], output_channels, kernel_size=1)
        self.final_X04 = nn.Conv2d(Nc[0], output_channels, kernel_size=1)

    @staticmethod
    def _up(x, ref):
        """Bilinearly upsample ``x`` to the spatial size of ``ref``.

        Resizing to an explicit size, rather than using ``scale_factor=2``, keeps
        the concatenations valid when H or W is not divisible by 16, where
        max-pooling floors odd sizes. With ``align_corners=True`` the sampling grid
        depends only on the input and output sizes. So for sizes divisible by 16
        (e.g. 256x256) the result equals the ``scale_factor=2`` result.
        """
        return F.interpolate(x, size=ref.shape[-2:], mode='bilinear', align_corners=True)

    def forward(self, x):
        """Run the nested encoder/decoder.

        Args:
            x: Input batch, assumed (B, input_channels, H, W); H and W need not be
               divisible by 16.

        Returns:
            dict with keys ``"final"``, ``"X01"``, ``"X02"``, ``"X03"``, ``"X04"``,
            each of shape (B, output_channels, H, W). ``"final"`` and ``"X04"`` are
            the same tensor.
        """
        up = self._up

        # Encoder (column 0)
        X0_0 = self.conv0_0(x)
        X1_0 = self.conv1_0(self.pool(X0_0))
        X2_0 = self.conv2_0(self.pool(X1_0))
        X3_0 = self.conv3_0(self.pool(X2_0))
        X4_0 = self.conv4_0(self.pool(X3_0))

        # Decoder column 1
        X0_1 = self.conv0_1(torch.cat([X0_0, up(X1_0, X0_0)], 1))
        X1_1 = self.conv1_1(torch.cat([X1_0, up(X2_0, X1_0)], 1))
        X2_1 = self.conv2_1(torch.cat([X2_0, up(X3_0, X2_0)], 1))
        X3_1 = self.conv3_1(torch.cat([X3_0, up(X4_0, X3_0)], 1))

        # Column 2
        X0_2 = self.conv0_2(torch.cat([X0_0, X0_1, up(X1_1, X0_0)], 1))
        X1_2 = self.conv1_2(torch.cat([X1_0, X1_1, up(X2_1, X1_0)], 1))
        X2_2 = self.conv2_2(torch.cat([X2_0, X2_1, up(X3_1, X2_0)], 1))

        # Column 3
        X0_3 = self.conv0_3(torch.cat([X0_0, X0_1, X0_2, up(X1_2, X0_0)], 1))
        X1_3 = self.conv1_3(torch.cat([X1_0, X1_1, X1_2, up(X2_2, X1_0)], 1))

        # Column 4
        X0_4 = self.conv0_4(torch.cat([X0_0, X0_1, X0_2, X0_3, up(X1_3, X0_0)], 1))

        output = self.final_X04(X0_4)

        return {"final": output, "X01": self.final_X01(X0_1), "X02": self.final_X02(X0_2), "X03": self.final_X03(X0_3), "X04": output}

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
model = UNetplusplus(input_channels=1, output_channels=1, base_ch=64)

# Replicate across GPUs with DataParallel when more than one is visible.
# NOTE: the wrapper stores the network under `.module`, so state_dict keys gain a
# "module." prefix; checkpoints saved here must be loaded into a wrapped model.
if not torch.cuda.device_count() <= 1:
    print(f"Using {torch.cuda.device_count()} GPUs with DataParallel")
    model = nn.DataParallel(model)

model.to(device)

Using 2 GPUs with DataParallel


DataParallel(
  (module): UNetplusplus(
    (conv0_0): ConvBlock(
      (block): Sequential(
        (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (4): ReLU(inplace=True)
      )
    )
    (conv1_0): ConvBlock(
      (block): Sequential(
        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (4): ReLU(inplace=True)
      )
    )
    (conv2_0): ConvBlock(
      (block): Sequential(
        (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(256, 25

In [None]:
#=====Metrics=====
def psnr_metric(y_true, y_pred, data_range=1.0):
    """Peak signal-to-noise ratio in dB, computed from the MSE over all elements.

    For a batch this is the PSNR of the pooled MSE, not the mean of per-image PSNRs.

    Args:
        y_true: Reference tensor.
        y_pred: Prediction tensor with the same shape as ``y_true``.
        data_range: Span of valid pixel values (1.0 for images in [0, 1]).

    Returns:
        0-dim tensor ``10 * log10(data_range**2 / MSE)``; ``inf`` when the inputs are identical.
    """
    mse = F.mse_loss(y_pred, y_true)
    return 10 * torch.log10(data_range ** 2 / mse)

def ssim_metric(y_true, y_pred):
    """Mean SSIM over a batch, comparing channel 0 of each image.

    Inputs are detached and copied to host memory. Each ``[b, 0]`` slice
    (assumed (H, W), i.e. inputs shaped (B, C, H, W) -- TODO confirm) is scored with
    skimage's SSIM at ``data_range=1.0`` and its default window size.

    Returns:
        0-dim tensor holding the mean of the per-image SSIM scores.
    """
    reference = y_true.detach().cpu().numpy()
    estimate = y_pred.detach().cpu().numpy()
    scores = [
        ssim(reference[b, 0], estimate[b, 0], data_range=1.0)
        for b in range(reference.shape[0])
    ]
    return torch.tensor(scores).mean()

# Module-level loss and optimizer. NOTE(review): train_model builds its own Adam
# optimizer and L1 loss, so these globals are not used by the training loop.
loss_fn = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)


In [None]:
#=====Model Training=====
def train_model(model, train_loader, val_loader, epochs=50, save_path="best_model_unetpp_DP_10.pth", lr=1e-3):
    """Train ``model`` with Adam and an L1 loss, checkpointing the best validation epoch.

    Only the ``"final"`` entry of the model's output dict enters the loss. Reported
    losses are the mean of the per-batch L1 losses.

    Args:
        model: Module mapping a noisy batch to a dict with key ``"final"``.
        train_loader: Iterable of ``(noisy, clean)`` training batches.
        val_loader: Iterable of ``(noisy, clean)`` validation batches.
        epochs: Number of passes over ``train_loader``.
        save_path: File the ``state_dict`` is written to (``torch.save``) whenever
            the validation loss improves.
        lr: Adam learning rate.

    Returns:
        dict with per-epoch ``"train_loss"`` and ``"val_loss"`` lists and the
        ``"best_val_loss"`` that was checkpointed (``inf`` if nothing was saved).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model.to(device)  # Move model to GPU if available

    # Fresh optimizer/loss per call, so repeated calls never share Adam state.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.L1Loss()

    best_loss = float('inf')
    history = {"train_loss": [], "val_loss": [], "best_val_loss": best_loss}

    for epoch in range(epochs):
        model.train()
        train_loss, n_train_batches = 0.0, 0
        for noisy, clean in train_loader:
            noisy, clean = noisy.to(device), clean.to(device)

            optimizer.zero_grad()
            outputs = model(noisy)["final"]  # deep-supervision heads are not used in the loss
            loss = loss_fn(outputs, clean)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            n_train_batches += 1

        model.eval()
        val_loss, n_val_batches = 0.0, 0
        with torch.no_grad():
            for noisy, clean in val_loader:
                noisy, clean = noisy.to(device), clean.to(device)
                outputs = model(noisy)["final"]
                val_loss += loss_fn(outputs, clean).item()
                n_val_batches += 1

        # Average over the batches actually seen. An empty loader yields NaN
        # instead of raising ZeroDivisionError.
        avg_train_loss = train_loss / n_train_batches if n_train_batches else float('nan')
        avg_val_loss = val_loss / n_val_batches if n_val_batches else float('nan')
        history["train_loss"].append(avg_train_loss)
        history["val_loss"].append(avg_val_loss)

        print(f"Epoch {epoch+1}/{epochs} => Train Loss: {avg_train_loss:.6f}, Validation Loss: {avg_val_loss:.6f}")

        # Save best model. NaN compares False, so an empty validation set never checkpoints.
        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            history["best_val_loss"] = best_loss
            torch.save(model.state_dict(), save_path)
            print(f"Model saved at epoch {epoch+1} with best validation loss: {best_loss:.6f}")

    return history

# Run training with the function's defaults (50 epochs, best checkpoint written
# to best_model_unetpp_DP_10.pth). The trailing ';' suppresses Jupyter's display
# of the return value.
train_model(model=model, train_loader=train_loader, val_loader=val_loader);


Using device: cuda
Epoch 1/50 => Train Loss: 0.055844, Validation Loss: 0.022192
Model saved at epoch 1 with best validation loss: 0.022192
Epoch 2/50 => Train Loss: 0.027498, Validation Loss: 0.033289
Epoch 3/50 => Train Loss: 0.025709, Validation Loss: 0.018470
Model saved at epoch 3 with best validation loss: 0.018470
Epoch 4/50 => Train Loss: 0.023079, Validation Loss: 0.027808
Epoch 5/50 => Train Loss: 0.023904, Validation Loss: 0.018481
Epoch 6/50 => Train Loss: 0.022512, Validation Loss: 0.021907
Epoch 7/50 => Train Loss: 0.021316, Validation Loss: 0.023252
Epoch 8/50 => Train Loss: 0.022704, Validation Loss: 0.020045
Epoch 9/50 => Train Loss: 0.020289, Validation Loss: 0.019766
Epoch 10/50 => Train Loss: 0.019892, Validation Loss: 0.020638
Epoch 11/50 => Train Loss: 0.019538, Validation Loss: 0.027600
Epoch 12/50 => Train Loss: 0.019505, Validation Loss: 0.028199
Epoch 13/50 => Train Loss: 0.019699, Validation Loss: 0.019559
Epoch 14/50 => Train Loss: 0.019304, Validation Loss:

In [None]:
# Build the LPIPS perceptual metric (AlexNet backbone) once, at module level, on
# the available device; test_model reads it as a global.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
loss_fn_lpips = lpips.LPIPS(net='alex').to(device)

#=====Testing=====
def test_model(model, test_loader, device, lpips_fn=None):
    """Print and return the mean PSNR, SSIM and LPIPS of ``model`` over ``test_loader``.

    PSNR and SSIM use skimage on channel 0 of every image with ``data_range=1.0``
    (SSIM with ``win_size=3``). LPIPS compares 3-channel copies of the images
    rescaled from [0, 1] to [-1, 1].

    Args:
        model: Module returning a dict whose ``"final"`` entry is the denoised batch.
        test_loader: Iterable of ``(noisy, clean)`` batches.
        device: Device the model and the data are moved to.
        lpips_fn: LPIPS module to use; defaults to the module-level ``loss_fn_lpips``.

    Returns:
        dict with ``"psnr"``, ``"ssim"``, ``"lpips"`` averages and ``"num_samples"``,
        or ``None`` if the loader produced no samples.

    Raises:
        RuntimeError: If LPIPS does not return exactly one score per image.
    """
    if lpips_fn is None:
        lpips_fn = loss_fn_lpips

    model.eval()
    model.to(device)

    total_psnr = 0.0
    total_ssim = 0.0
    total_lpips = 0.0
    num_samples = 0

    with torch.no_grad():
        for noisy, clean in test_loader:
            noisy, clean = noisy.to(device), clean.to(device)
            outputs = model(noisy)["final"]
            batch_size = noisy.shape[0]

            # One device->host copy per batch instead of one per image.
            clean_np = clean[:, 0].cpu().numpy()      # (B, H, W) for grayscale
            outputs_np = outputs[:, 0].cpu().numpy()  # (B, H, W) for grayscale
            for y_true, y_pred in zip(clean_np, outputs_np):
                total_psnr += psnr(y_true, y_pred, data_range=1.0)
                total_ssim += ssim(y_true, y_pred, data_range=1.0, win_size=3)

            # LPIPS on the whole batch: replicate grayscale to 3 channels and
            # rescale [0, 1] -> [-1, 1].
            out_img, tgt_img = outputs, clean
            if out_img.shape[1] == 1:
                out_img = out_img.repeat(1, 3, 1, 1)
                tgt_img = tgt_img.repeat(1, 3, 1, 1)
            lpips_scores = lpips_fn(out_img * 2 - 1, tgt_img * 2 - 1).flatten().tolist()
            # Expect one score per image (LPIPS with spatial output disabled).
            if len(lpips_scores) != batch_size:
                raise RuntimeError(
                    f"expected {batch_size} LPIPS scores, got {len(lpips_scores)}"
                )
            total_lpips += sum(lpips_scores)

            num_samples += batch_size

    if num_samples == 0:
        print("No test samples evaluated.")
        return None

    avg_psnr = total_psnr / num_samples
    avg_ssim = total_ssim / num_samples
    avg_lpips = total_lpips / num_samples

    print(f"Average PSNR: {avg_psnr:.2f} dB")
    print(f"Average SSIM: {avg_ssim:.4f}")
    print(f"Average LPIPS: {avg_lpips:.4f}")
    print(f"Test samples evaluated: {num_samples}")

    return {"psnr": avg_psnr, "ssim": avg_ssim, "lpips": avg_lpips, "num_samples": num_samples}


# Restore the best checkpoint written during training, then evaluate on the test split.
checkpoint = torch.load('best_model_unetpp_DP_10.pth', map_location=device)
model.load_state_dict(checkpoint)

test_model(model, test_loader, device);

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]




Loading model from: /home/sulaimon/.local/lib/python3.10/site-packages/lpips/weights/v0.1/alex.pth
Average PSNR: 34.09 dB
Average SSIM: 0.9084
Average LPIPS: 0.1565
Test samples evaluated: 2551


In [None]:
#=====Save Sample data (Clean, Noisy, Denoised)=====
def save_test_image_triplets_from_model(model, test_loader, device="cpu", output_dir="saved_unet++DT_images", num_images=10):
    """Save (noisy, denoised, clean) PNG triplets for images from the first test batch.

    For each saved index ``k`` (1-based, zero-padded to 2 digits) this writes
    ``img_kk_noisy.png``, ``img_kk_denoised.png`` and ``img_kk_clean.png`` into
    ``output_dir`` using ``save_image``.

    Args:
        model: Module returning a dict whose ``"final"`` entry is the denoised batch.
        test_loader: Iterable of ``(noisy, clean)`` batches; only the first is used.
        device: Device the batch is moved to before inference. The model is not
            moved here, so it must already accept inputs on ``device``.
        output_dir: Directory for the PNGs (created if missing).
        num_images: Maximum number of triplets to save (capped at the batch size).
    """
    # Ensure output directory exists
    os.makedirs(output_dir, exist_ok=True)

    # Get one batch from the test set (an empty loader would raise StopIteration).
    first_batch = next(iter(test_loader), None)
    if first_batch is None:
        print(f"Test loader is empty; nothing saved to '{output_dir}'")
        return
    noisy_imgs, clean_imgs = first_batch
    noisy_imgs = noisy_imgs.to(device)
    clean_imgs = clean_imgs.to(device)

    # Run inference
    model.eval()
    with torch.no_grad():
        output_imgs = model(noisy_imgs)["final"]

    # Move to CPU for saving
    noisy_imgs = noisy_imgs.cpu()
    output_imgs = output_imgs.cpu()
    clean_imgs = clean_imgs.cpu()

    n_saved = min(max(num_images, 0), len(clean_imgs))
    for i in range(n_saved):
        save_image(noisy_imgs[i], os.path.join(output_dir, f"img_{i+1:02d}_noisy.png"))
        save_image(output_imgs[i], os.path.join(output_dir, f"img_{i+1:02d}_denoised.png"))
        save_image(clean_imgs[i], os.path.join(output_dir, f"img_{i+1:02d}_clean.png"))

    print(f"Saved {n_saved * 3} images (triplets) to '{output_dir}'")


save_test_image_triplets_from_model(model, test_loader, device=device)



In [None]:
#=====Save Metrics to CSV=====
def save_metrics_csv_from_model(model, test_loader, device="cpu", csv_path="unetpp_DP_10.csv", lpips_fn=None):
    """Write per-image PSNR, SSIM and LPIPS for the test set to a CSV file.

    The CSV contains a header row, one row per image (1-based "Image Index"), a
    blank row, and a final "Average" row. PSNR and SSIM come from skimage with
    ``data_range=1.0`` (SSIM with ``win_size=3``). LPIPS compares 3-channel copies
    rescaled from [0, 1] to [-1, 1].

    Args:
        model: Module returning a dict whose ``"final"`` entry is the denoised batch.
        test_loader: Iterable of ``(noisy, clean)`` batches.
        device: Device the model and data are moved to.
        csv_path: Output CSV path (overwritten).
        lpips_fn: Optional pre-built LPIPS module to reuse. When ``None``, a new
            ``lpips.LPIPS(net='alex')`` is constructed, which reloads its weights.

    Returns:
        dict with ``"psnr"``, ``"ssim"``, ``"lpips"`` means and ``"count"``; the
        means are NaN when the loader yields no images.

    Raises:
        RuntimeError: If LPIPS does not return exactly one score per image.
    """
    model.eval()
    model.to(device)

    if lpips_fn is None:
        lpips_fn = lpips.LPIPS(net='alex').to(device)

    psnr_list = []
    ssim_list = []
    lpips_list = []
    count = 0

    with open(csv_path, mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(["Image Index", "PSNR", "SSIM", "LPIPS"])

        with torch.no_grad():
            for noisy, clean in test_loader:
                noisy = noisy.to(device)
                clean = clean.to(device)

                outputs = model(noisy)["final"]

                clean_np = clean.squeeze(1).cpu().numpy()     # (B, H, W)
                outputs_np = outputs.squeeze(1).cpu().numpy() # (B, H, W)

                # LPIPS for the whole batch in one call: 3-channel input in [-1, 1].
                out_img, tgt_img = outputs, clean
                if out_img.shape[1] == 1:
                    out_img = out_img.repeat(1, 3, 1, 1)
                    tgt_img = tgt_img.repeat(1, 3, 1, 1)
                batch_lpips = lpips_fn(out_img * 2 - 1, tgt_img * 2 - 1).flatten().tolist()
                # Expect one score per image (LPIPS with spatial output disabled).
                if len(batch_lpips) != len(clean_np):
                    raise RuntimeError(
                        f"expected {len(clean_np)} LPIPS scores, got {len(batch_lpips)}"
                    )

                for y_true, y_pred, lpips_val in zip(clean_np, outputs_np, batch_lpips):
                    psnr_val = psnr(y_true, y_pred, data_range=1.0)
                    ssim_val = ssim(y_true, y_pred, data_range=1.0, win_size=3)

                    psnr_list.append(psnr_val)
                    ssim_list.append(ssim_val)
                    lpips_list.append(lpips_val)

                    count += 1
                    writer.writerow([count, psnr_val, ssim_val, lpips_val])

        # np.mean([]) warns and returns NaN; make the empty case explicit instead.
        averages = {
            "psnr": np.mean(psnr_list) if psnr_list else float('nan'),
            "ssim": np.mean(ssim_list) if ssim_list else float('nan'),
            "lpips": np.mean(lpips_list) if lpips_list else float('nan'),
        }
        writer.writerow([])  # Blank line
        writer.writerow(["Average", averages["psnr"], averages["ssim"], averages["lpips"]])

    print(f" Saved PSNR/SSIM/LPIPS metrics for {count} images to '{csv_path}'")
    return {**averages, "count": count}
    
save_metrics_csv_from_model(model, test_loader, device=device);

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /home/sulaimon/.local/lib/python3.10/site-packages/lpips/weights/v0.1/alex.pth
 Saved PSNR/SSIM/LPIPS metrics for 2551 images to 'unetpp_DP_10.csv'


In [None]:
#=====Visualise Clean, Noisy, and Denoised Image Test sets=====
def visualize_results(model, test_loader, device, num_samples=10):
    """Plot noisy / denoised / ground-truth panels for images from the first test batch.

    All panels share a fixed [0, 1] gray scale, so brightness and contrast can be
    compared across columns. By default ``imshow`` rescales each image to its own
    min/max, which hides intensity shifts and out-of-range predictions.

    Args:
        model: Module returning a dict whose ``"final"`` entry is the denoised batch.
        test_loader: Iterable of ``(noisy, clean)`` batches; only the first is used.
        device: Device the model and batch are moved to.
        num_samples: Maximum number of rows to plot (capped at the batch size).
    """
    model.eval()
    model.to(device)

    noisy, clean = next(iter(test_loader))
    noisy, clean = noisy.to(device), clean.to(device)

    with torch.no_grad():
        outputs = model(noisy)["final"]

    # Move tensors to CPU for plotting
    noisy = noisy.cpu()
    outputs = outputs.cpu()
    clean = clean.cpu()

    # Cap num_samples if the batch is smaller; plt.subplots cannot draw 0 rows.
    num_samples = min(num_samples, noisy.shape[0])
    if num_samples <= 0:
        print("Nothing to visualise: num_samples must be positive and the batch non-empty.")
        return

    # squeeze=False always gives a 2-D axes array, even for a single row.
    fig, axes = plt.subplots(num_samples, 3, figsize=(12, 4 * num_samples), squeeze=False)

    columns = (("Noisy Image", noisy), ("Denoised Image", outputs), ("Ground Truth", clean))
    for i in range(num_samples):
        for j, (title, images) in enumerate(columns):
            ax = axes[i, j]
            ax.imshow(images[i].squeeze(0), cmap="gray", vmin=0.0, vmax=1.0)
            ax.set_title(title)
            ax.axis("off")

    plt.tight_layout()
    plt.show()

# Example usage: plot the first test batch on whichever device is available.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
visualize_results(model, test_loader, device=device)
