In [None]:
import torch
from os import path
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms

import h5py
import numpy as np
import wget
import matplotlib.pyplot as plt
from skimage.metrics import peak_signal_noise_ratio as psnr, structural_similarity as ssim
import csv
from torchvision.utils import save_image

import os
import glob
from PIL import Image
import numpy as np
import torch
import csv

import lpips




In [2]:
## Check that the data path is correct and collect the PNG files it contains
data_path = "/home/sulaimon/EXPERIMENT/23/images_001"
# glob returns matches in arbitrary, filesystem-dependent order. Sorting fixes the
# dataset order, so the seeded train/val/test split is the same on every run.
image_files = sorted(glob.glob(os.path.join(data_path, '*.png')))
print(f'Found {len(image_files)} Images.')

Found 14999 Images.


In [None]:
#=====Load Data=====
def load_xray_image(data_path, size=(256, 256)):
    """Load one image file as a normalized, single-channel float32 array.

    Args:
        data_path: Path of the image file to open.
        size: Target ``(width, height)`` handed to PIL's ``resize``.
            Defaults to ``(256, 256)``.

    Returns:
        np.ndarray of dtype float32 with a leading channel axis of length 1
        and values scaled to [0, 1].
    """
    # The context manager releases the file handle even if decoding or
    # resizing raises part-way through.
    with Image.open(data_path) as raw_image:
        gray_image = raw_image.convert('L').resize(size)  # 8-bit grayscale, resized

    img_array = np.asarray(gray_image, dtype=np.float32) / 255.0  # Normalize to [0,1]
    return np.expand_dims(img_array, axis=0)  # Add channel dimension

# Decode every discovered file and pack the results into a single NumPy array
InputImages = np.array(list(map(load_xray_image, image_files)))
print("Dataset Shape:", InputImages.shape)

Dataset Shape: (14999, 1, 256, 256)


In [None]:
#=====Add Noise=====
def add_gaussian_noise(images, mean=0.1, stddev=0.1, generator=None):
    """Return ``images`` plus Gaussian noise, clamped to the range [0, 1].

    Args:
        images: Tensor of image intensities; any shape is accepted.
        mean: Mean of the noise distribution.
            NOTE(review): the default of 0.1 adds a constant brightness offset
            on top of the random component. Confirm this is intended.
        stddev: Standard deviation of the noise distribution (must be >= 0).
        generator: Optional ``torch.Generator`` used to draw the noise so the
            result is reproducible independently of the global RNG state. It
            must live on the same device as ``images``. ``None`` uses the
            global generator.

    Returns:
        Tensor with the same shape as ``images`` and values in [0, 1].
    """
    # Sample directly on the input's device. Without this the noise is always
    # created on the CPU and the addition fails for GPU tensors.
    noise = torch.normal(
        mean,
        std=stddev,
        size=images.shape,
        generator=generator,
        device=images.device,
    )
    return torch.clamp(images + noise, 0., 1.)

# Seed before sampling so the synthetic noise is identical on every run.
torch.manual_seed(42)
# as_tensor reuses the existing float32 buffer instead of copying the whole
# dataset, and is a no-op if this cell is re-run once InputImages is a tensor.
InputImages = torch.as_tensor(InputImages, dtype=torch.float32)
noisy_images = add_gaussian_noise(InputImages)


In [None]:
#=====Data Loader=====
class X_rayDataset(Dataset):
    """Map-style dataset of aligned (noisy, clean) image pairs.

    Args:
        noisy: Sized, indexable collection of degraded images.
        clean: Sized, indexable collection of reference images, where
            ``clean[i]`` is the target for ``noisy[i]``.

    Raises:
        ValueError: If the two collections do not have the same length.
    """

    def __init__(self, noisy, clean):
        # A length mismatch would otherwise surface much later as an
        # IndexError, or silently drop the unmatched tail.
        if len(noisy) != len(clean):
            raise ValueError(
                f"noisy and clean must have the same length, "
                f"got {len(noisy)} and {len(clean)}"
            )
        self.noisy = noisy
        self.clean = clean

    def __len__(self):
        """Number of image pairs."""
        return len(self.noisy)

    def __getitem__(self, idx):
        """Return the ``(noisy, clean)`` pair stored at position ``idx``."""
        return self.noisy[idx], self.clean[idx]


## Build the paired dataset, split it, and wrap each split in a DataLoader
dataset = X_rayDataset(noisy_images, InputImages)

# Fix the RNG seeds so the random split below is reproducible
torch.manual_seed(42)
np.random.seed(42)

# 50% train / 33% validation / whatever remains goes to test
n_total = len(dataset)
train_size = int(0.5 * n_total)
val_size = int(0.33 * n_total)
test_size = n_total - (train_size + val_size)

split_lengths = [train_size, val_size, test_size]
train_set, val_set, test_set = random_split(dataset, split_lengths)

# Only the training loader shuffles; validation and test keep a fixed order
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = DataLoader(val_set, batch_size=32)
test_loader = DataLoader(test_set, batch_size=32, shuffle=False)

##Verify the samples
def print_dataset_shapes(loader, name):
    """Print the noisy/clean tensor shapes of the first batch drawn from ``loader``.

    Args:
        loader: Iterable yielding ``(noisy, clean)`` batches.
        name: Label placed at the start of the printed line.
    """
    first_noisy, first_clean = next(iter(loader))
    print(f"{name} set: Noisy shape: {first_noisy.shape}, Clean shape: {first_clean.shape}")

# Show the first-batch shapes of every split, then the size of each split
for split_loader, split_name in (
    (train_loader, "Train"),
    (val_loader, "Validation"),
    (test_loader, "Test"),
):
    print_dataset_shapes(split_loader, split_name)

print(f"Train set size: {len(train_set)}")
print(f"Val Set: {len(val_set)}")
print(f"Test Set: {len(test_set)}")

Train set: Noisy shape: torch.Size([32, 1, 256, 256]), Clean shape: torch.Size([32, 1, 256, 256])
Validation set: Noisy shape: torch.Size([32, 1, 256, 256]), Clean shape: torch.Size([32, 1, 256, 256])
Test set: Noisy shape: torch.Size([32, 1, 256, 256]), Clean shape: torch.Size([32, 1, 256, 256])
Train set size: 7499
Val Set: 4949
Test Set: 2551


In [None]:
#=====UNet Model Definition=====
class UNet(nn.Module):
    """U-Net with four down-sampling stages, a bottleneck, and four up-sampling stages.

    ``forward`` returns a dict holding the activation of every stage
    (``"conv1"`` ... ``"conv9"``) plus the single-channel prediction under
    ``"final"``.

    Args:
        input_channels: Number of channels in the input image.
        Nc: Channel width of the first stage; it doubles at each deeper level.
    """

    def __init__(self, input_channels=1, Nc=64):
        super().__init__()

        # Channel widths of the five resolution levels: Nc, 2Nc, 4Nc, 8Nc, 16Nc
        w1, w2, w3, w4, w5 = (Nc * 2 ** level for level in range(5))

        # Encoder: conv block followed by 2x max-pooling at each level
        self.conv1 = self.conv_block(input_channels, w1)
        self.pool1 = nn.MaxPool2d(2)

        self.conv2 = self.conv_block(w1, w2)
        self.pool2 = nn.MaxPool2d(2)

        self.conv3 = self.conv_block(w2, w3)
        self.pool3 = nn.MaxPool2d(2)

        self.conv4 = self.conv_block(w3, w4)
        self.pool4 = nn.MaxPool2d(2)

        # Bottleneck
        self.conv5 = self.conv_block(w4, w5)

        # Decoder: 2x transposed-conv up-sampling, then a conv block that
        # consumes the up-sampled features concatenated with the skip tensor
        # (hence the doubled input width).
        self.upconv6 = nn.ConvTranspose2d(w5, w4, kernel_size=2, stride=2)
        self.conv6 = self.conv_block(w5, w4)

        self.upconv7 = nn.ConvTranspose2d(w4, w3, kernel_size=2, stride=2)
        self.conv7 = self.conv_block(w4, w3)

        self.upconv8 = nn.ConvTranspose2d(w3, w2, kernel_size=2, stride=2)
        self.conv8 = self.conv_block(w3, w2)

        self.upconv9 = nn.ConvTranspose2d(w2, w1, kernel_size=2, stride=2)
        self.conv9 = self.conv_block(w2, w1)

        # 1x1 projection down to a single (grayscale) output channel
        self.final_conv = nn.Conv2d(w1, 1, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        """Build Conv3x3 -> ReLU -> Conv3x3 -> BatchNorm -> ReLU as one ``nn.Sequential``.

        Padding of 1 keeps the spatial size unchanged. BatchNorm follows only
        the second convolution.
        """
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the network and return every stage's activation.

        Args:
            x: Input batch. Four 2x poolings are undone by four 2x
                up-samplings, so height and width presumably need to be
                multiples of 16 for the skip concatenations to line up.

        Returns:
            dict mapping ``"conv1"`` ... ``"conv9"`` and ``"final"`` to tensors.
        """
        outputs = {}

        # Encoder: remember each stage's output for the skip connections
        encoder_stages = (
            ("conv1", self.conv1, self.pool1),
            ("conv2", self.conv2, self.pool2),
            ("conv3", self.conv3, self.pool3),
            ("conv4", self.conv4, self.pool4),
        )
        features = x
        for stage_name, block, pool in encoder_stages:
            outputs[stage_name] = block(features)
            features = pool(outputs[stage_name])

        outputs["conv5"] = self.conv5(features)  # Bottleneck
        features = outputs["conv5"]

        # Decoder: up-sample, concatenate the matching encoder output
        # (skip first, up-sampled second), then refine
        decoder_stages = (
            ("conv6", self.upconv6, self.conv6, "conv4"),
            ("conv7", self.upconv7, self.conv7, "conv3"),
            ("conv8", self.upconv8, self.conv8, "conv2"),
            ("conv9", self.upconv9, self.conv9, "conv1"),
        )
        for stage_name, upsample, block, skip_name in decoder_stages:
            merged = torch.cat((outputs[skip_name], upsample(features)), dim=1)
            features = block(merged)
            outputs[stage_name] = features

        outputs["final"] = self.final_conv(features)

        return outputs  # All intermediate outputs plus the prediction

# Example usage with DataParallel
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(input_channels=1, Nc=64)

# Replicate the model across GPUs only when more than one is visible
gpu_count = torch.cuda.device_count()
if gpu_count > 1:
    print(f"Using {gpu_count} GPUs with DataParallel")
    model = nn.DataParallel(model)

model.to(device)

Using 2 GPUs with DataParallel


DataParallel(
  (module): UNet(
    (conv1): Sequential(
      (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (4): ReLU(inplace=True)
    )
    (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (4): ReLU(inplace=True)
    )
    (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv3): Sequential(
      (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): Re

In [None]:
#=====Metrics=====
def psnr_metric(y_true, y_pred, data_range=1.0):
    """Peak signal-to-noise ratio, in dB, between a prediction and its reference.

    Args:
        y_true: Reference tensor.
        y_pred: Predicted tensor with the same shape as ``y_true``.
        data_range: Difference between the largest and smallest possible pixel
            value (1.0 for images scaled to [0, 1], 255 for 8-bit images).

    Returns:
        Zero-dimensional tensor holding the PSNR over all elements. It is
        ``inf`` when the two tensors are identical (MSE of zero).
    """
    mse = F.mse_loss(y_pred, y_true)
    return 10 * torch.log10((data_range ** 2) / mse)

def ssim_metric(y_true, y_pred):
    """Mean SSIM over a batch, computed on channel 0 of each image.

    Both tensors are detached and copied to NumPy, then scored one image at a
    time with a data range of 1.0.

    Returns:
        Zero-dimensional tensor with the mean of the per-image scores.
    """
    true_np = y_true.detach().cpu().numpy()
    pred_np = y_pred.detach().cpu().numpy()
    # Score each batch element on its first (and only) channel
    per_image_scores = [
        ssim(true_np[idx, 0], pred_np[idx, 0], data_range=1.0)
        for idx in range(true_np.shape[0])
    ]
    return torch.tensor(per_image_scores).mean()


# Module-level optimizer and loss; train_model creates its own local
# optimizer and loss rather than using these.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
loss_fn = nn.L1Loss()


In [None]:
#=====Training=====
def train_model(model, train_loader, val_loader, epochs=50, save_path="best_model_unet_DP_10.pth", lr=1e-3):
    """Train ``model`` with Adam and L1 loss, checkpointing the best validation epoch.

    The model is expected to return a dict whose ``"final"`` entry is the
    prediction compared against the clean target.

    Args:
        model: Network to train; moved to CUDA when available, otherwise CPU.
        train_loader: Sized loader yielding ``(noisy, clean)`` training batches.
        val_loader: Sized loader yielding ``(noisy, clean)`` validation batches.
        epochs: Number of passes over ``train_loader``.
        save_path: File that receives ``model.state_dict()`` whenever the
            validation loss improves.
        lr: Learning rate for the Adam optimizer.

    Raises:
        ValueError: If either loader would yield no batches.
    """
    # Fail fast instead of hitting a ZeroDivisionError after a full epoch.
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must each yield at least one batch")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model.to(device)  # Move model to GPU if available

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.L1Loss()

    best_loss = float('inf')

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        train_seen = 0

        for noisy, clean in train_loader:
            noisy, clean = noisy.to(device), clean.to(device)

            optimizer.zero_grad()
            outputs = model(noisy)["final"]
            loss = loss_fn(outputs, clean)
            loss.backward()
            optimizer.step()

            # loss is a per-batch mean. Weight it by the batch size so a short
            # final batch does not count as much as a full one.
            batch_size = noisy.size(0)
            train_loss += loss.item() * batch_size
            train_seen += batch_size

        val_loss = 0.0
        val_seen = 0
        model.eval()
        with torch.no_grad():
            for noisy, clean in val_loader:
                noisy, clean = noisy.to(device), clean.to(device)

                outputs = model(noisy)["final"]
                loss = loss_fn(outputs, clean)

                batch_size = noisy.size(0)
                val_loss += loss.item() * batch_size
                val_seen += batch_size

        # Per-sample averages over the whole epoch
        avg_train_loss = train_loss / train_seen
        avg_val_loss = val_loss / val_seen

        print(f"Epoch {epoch+1}/{epochs} => Train Loss: {avg_train_loss:.6f}, Validation Loss: {avg_val_loss:.6f}")

        # Save best model
        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            torch.save(model.state_dict(), save_path)
            print(f"Model saved at epoch {epoch+1} with best validation loss: {best_loss:.6f}")

# Example: run training with the default epoch count and checkpoint path spelled out
train_model(model, train_loader, val_loader, epochs=50, save_path="best_model_unet_DP_10.pth")


Using device: cuda
Epoch 1/50 => Train Loss: 0.044421, Validation Loss: 0.047310
Model saved at epoch 1 with best validation loss: 0.047310
Epoch 2/50 => Train Loss: 0.031486, Validation Loss: 0.025958
Model saved at epoch 2 with best validation loss: 0.025958
Epoch 3/50 => Train Loss: 0.028396, Validation Loss: 0.042595
Epoch 4/50 => Train Loss: 0.026298, Validation Loss: 0.030584
Epoch 5/50 => Train Loss: 0.025326, Validation Loss: 0.025915
Model saved at epoch 5 with best validation loss: 0.025915
Epoch 6/50 => Train Loss: 0.024750, Validation Loss: 0.039112
Epoch 7/50 => Train Loss: 0.022577, Validation Loss: 0.023523
Model saved at epoch 7 with best validation loss: 0.023523
Epoch 8/50 => Train Loss: 0.023542, Validation Loss: 0.035316
Epoch 9/50 => Train Loss: 0.021107, Validation Loss: 0.026951
Epoch 10/50 => Train Loss: 0.022833, Validation Loss: 0.020273
Model saved at epoch 10 with best validation loss: 0.020273
Epoch 11/50 => Train Loss: 0.021419, Validation Loss: 0.021530
E

In [None]:
# Pick the evaluation device, then build one shared LPIPS model (net='alex') on it
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
loss_fn_lpips = lpips.LPIPS(net='alex')
loss_fn_lpips = loss_fn_lpips.to(device)

#=====Model Testing=====
def test_model(model, test_loader, device):
    """Evaluate ``model`` on ``test_loader`` and print average PSNR, SSIM and LPIPS.

    Relies on the module-level ``loss_fn_lpips`` model for the LPIPS score.

    Args:
        model: Network returning a dict whose ``"final"`` entry is the prediction.
        test_loader: Loader yielding ``(noisy, clean)`` batches.
        device: Device on which inference and LPIPS are run.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    model.eval()
    model.to(device)

    total_psnr = 0.0
    total_ssim = 0.0
    total_lpips = 0.0
    num_samples = 0

    with torch.no_grad():
        for noisy, clean in test_loader:
            noisy, clean = noisy.to(device), clean.to(device)
            outputs = model(noisy)["final"]

            # One device-to-host copy per batch instead of two per image.
            clean_np = clean[:, 0].cpu().numpy()      # channel 0 of every image
            outputs_np = outputs[:, 0].cpu().numpy()

            for y_true, y_pred in zip(clean_np, outputs_np):
                total_psnr += psnr(y_true, y_pred, data_range=1.0)
                total_ssim += ssim(y_true, y_pred, data_range=1.0, win_size=3)

            # LPIPS: convert single-channel to 3-channel and rescale [0, 1] to
            # [-1, 1]. The whole batch is scored in one forward pass.
            lpips_pred, lpips_target = outputs, clean
            if outputs.shape[1] == 1:
                lpips_pred = outputs.repeat(1, 3, 1, 1)
                lpips_target = clean.repeat(1, 3, 1, 1)

            lpips_scores = loss_fn_lpips((lpips_pred * 2) - 1, (lpips_target * 2) - 1)
            total_lpips += lpips_scores.sum().item()  # one distance per image

            num_samples += noisy.shape[0]

    # Explicit error instead of a ZeroDivisionError on an empty loader.
    if num_samples == 0:
        raise ValueError("test_loader yielded no samples; nothing to evaluate")

    avg_psnr = total_psnr / num_samples
    avg_ssim = total_ssim / num_samples
    avg_lpips = total_lpips / num_samples

    print(f"Average PSNR: {avg_psnr:.2f} dB")
    print(f"Average SSIM: {avg_ssim:.4f}")
    print(f"Average LPIPS: {avg_lpips:.4f}")
    print(f"Test samples evaluated: {num_samples}")


# Restore the best checkpoint. A state dict saved from an nn.DataParallel
# wrapper has every key prefixed with "module."; a plain model has no prefix.
# Reconcile the two so the checkpoint loads whether or not the current model is
# wrapped (for example when the number of visible GPUs differs from the
# training run).
checkpoint_state = torch.load('best_model_unet_DP_10.pth', map_location=device)
model_is_wrapped = isinstance(model, nn.DataParallel)
keys_are_prefixed = all(key.startswith("module.") for key in checkpoint_state)

if keys_are_prefixed and not model_is_wrapped:
    checkpoint_state = {key[len("module."):]: value for key, value in checkpoint_state.items()}
elif model_is_wrapped and not keys_are_prefixed:
    checkpoint_state = {f"module.{key}": value for key, value in checkpoint_state.items()}

model.load_state_dict(checkpoint_state)

test_model(model, test_loader, device)

In [None]:
#=====Save Sample Test Data=====
def save_test_image_triplets_from_model(model, test_loader, device="cpu", output_dir="saved_unet_imagesDT", num_images=10):
    """Save noisy / denoised / clean PNG triplets for images of the first test batch.

    Args:
        model: Network returning a dict whose ``"final"`` entry is the prediction.
        test_loader: Loader yielding ``(noisy, clean)`` batches; only the first
            batch is used.
        device: Device on which inference is run.
        output_dir: Directory that receives the PNG files (created if missing).
        num_images: Maximum number of triplets to save; capped at the batch size.
    """
    # Ensure output directory exists
    os.makedirs(output_dir, exist_ok=True)

    # Get one batch from the test set
    noisy_imgs, clean_imgs = next(iter(test_loader))
    noisy_imgs = noisy_imgs.to(device)

    # Put the model on the same device as its input, as the other
    # evaluation helpers in this notebook do.
    model.to(device)
    model.eval()
    with torch.no_grad():
        output_imgs = model(noisy_imgs)["final"]

    # Move to CPU for saving
    noisy_imgs = noisy_imgs.cpu()
    output_imgs = output_imgs.cpu()
    clean_imgs = clean_imgs.cpu()

    num_saved = max(0, min(num_images, len(clean_imgs)))
    for i in range(num_saved):
        stem = os.path.join(output_dir, f"img_{i+1:02d}")
        save_image(noisy_imgs[i], f"{stem}_noisy.png")
        save_image(output_imgs[i], f"{stem}_denoised.png")
        save_image(clean_imgs[i], f"{stem}_clean.png")

    print(f"Saved {num_saved * 3} images (triplets) to '{output_dir}'")

save_test_image_triplets_from_model(model, test_loader, device=device)

In [None]:
#=====Save Metrics to CSV=====
def save_metrics_csv_from_model(model, test_loader, device="cpu", csv_path="unet_DP_10.csv"):
    """Write per-image PSNR, SSIM and LPIPS for the test set to a CSV file.

    The file holds a header row, one row per image (1-based index), a blank
    row, and a final ``Average`` row. A fresh LPIPS model (net='alex') is built
    on ``device`` for the call.

    Args:
        model: Network returning a dict whose ``"final"`` entry is the prediction.
        test_loader: Loader yielding ``(noisy, clean)`` batches.
        device: Device on which inference and LPIPS are run.
        csv_path: Destination file; overwritten if it exists.
    """
    model.eval()
    model.to(device)

    loss_fn_lpips = lpips.LPIPS(net='alex').to(device)

    psnr_list, ssim_list, lpips_list = [], [], []

    with open(csv_path, mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(["Image Index", "PSNR", "SSIM", "LPIPS"])

        with torch.no_grad():
            for noisy, clean in test_loader:
                noisy = noisy.to(device)
                clean = clean.to(device)

                outputs = model(noisy)["final"]

                # NumPy copies with the channel axis squeezed out, for skimage
                clean_np = clean.squeeze(1).cpu().numpy()
                outputs_np = outputs.squeeze(1).cpu().numpy()

                for i in range(clean_np.shape[0]):
                    target_2d, pred_2d = clean_np[i], outputs_np[i]
                    psnr_val = psnr(target_2d, pred_2d, data_range=1.0)
                    ssim_val = ssim(target_2d, pred_2d, data_range=1.0, win_size=3)

                    # LPIPS input: add a batch axis, tile grayscale to three
                    # channels, then rescale [0, 1] to [-1, 1]
                    pred_rgb = outputs[i].unsqueeze(0)
                    target_rgb = clean[i].unsqueeze(0)
                    if pred_rgb.shape[1] == 1:
                        pred_rgb = pred_rgb.repeat(1, 3, 1, 1)
                        target_rgb = target_rgb.repeat(1, 3, 1, 1)
                    pred_rgb = (pred_rgb * 2) - 1
                    target_rgb = (target_rgb * 2) - 1

                    lpips_val = loss_fn_lpips(pred_rgb.to(device), target_rgb.to(device)).item()

                    psnr_list.append(psnr_val)
                    ssim_list.append(ssim_val)
                    lpips_list.append(lpips_val)

                    # The list length after appending is the 1-based image index
                    writer.writerow([len(psnr_list), psnr_val, ssim_val, lpips_val])

        writer.writerow([])  # Blank line
        averages = [np.mean(psnr_list), np.mean(ssim_list), np.mean(lpips_list)]
        writer.writerow(["Average", *averages])

    count = len(psnr_list)
    print(f" Saved PSNR/SSIM/LPIPS metrics for {count} images to '{csv_path}'")
    
save_metrics_csv_from_model(model, test_loader, device=device)

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /home/sulaimon/.local/lib/python3.10/site-packages/lpips/weights/v0.1/alex.pth
 Saved PSNR/SSIM/LPIPS metrics for 2551 images to 'unet_DP_10.csv'


In [None]:
#=====Visualise Clean, Noisy, and Denoised Image Test sets=====
def visualize_results(model, test_loader, device, num_samples=10):
    """Plot noisy input, model output and ground truth side by side.

    Only the first batch of ``test_loader`` is used. One figure row is drawn
    per image, up to ``num_samples`` rows or the batch size, whichever is
    smaller.

    Args:
        model: Network returning a dict whose ``"final"`` entry is the prediction.
        test_loader: Loader yielding ``(noisy, clean)`` batches.
        device: Device on which inference is run.
        num_samples: Maximum number of rows to draw.
    """
    model.eval()
    model.to(device)

    noisy, clean = next(iter(test_loader))
    noisy, clean = noisy.to(device), clean.to(device)

    with torch.no_grad():
        outputs = model(noisy)["final"]

    # Bring everything back to the CPU for matplotlib
    noisy, outputs, clean = noisy.cpu(), outputs.cpu(), clean.cpu()

    # Cap num_samples if the batch is smaller
    num_samples = min(num_samples, noisy.shape[0])

    # squeeze=False keeps `axes` two-dimensional even for a single row
    fig, axes = plt.subplots(num_samples, 3, squeeze=False, figsize=(12, 4 * num_samples))

    columns = (
        (noisy, "Noisy Image"),
        (outputs, "Denoised Image"),
        (clean, "Ground Truth"),
    )
    for row in range(num_samples):
        for col, (batch, title) in enumerate(columns):
            panel = axes[row, col]
            panel.imshow(batch[row].squeeze(0), cmap="gray")
            panel.set_title(title)
            panel.axis("off")

    plt.tight_layout()
    plt.show()

# Example usage: pick the device, then plot up to ten rows from the first test batch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
visualize_results(model, test_loader, device, num_samples=10)
