Mount Google Drive (where the dataset is stored)

In [None]:
# Mount Google Drive into the Colab runtime at /content/drive; the dataset path used
# later in this notebook lives under that mount point.
from google.colab import drive
drive.mount('/content/drive')


Define the Autoencoder Model



In [None]:
import torch
import torch.nn as nn

# Pick the compute device once: CUDA when PyTorch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Define the attention block
class AttentionBlock(nn.Module):
    """Gating block that rescales a feature map by a learned weight map in (0, 1).

    Two 1x1 convolutions form a bottleneck (``in_channels -> hidden -> in_channels``)
    followed by a sigmoid. The resulting map has the same shape as the input and
    is multiplied element-wise with it.

    Args:
        in_channels (int): Number of channels of the incoming feature map.
        reduction (int): Bottleneck divisor. The hidden width is
            ``in_channels // reduction`` but never less than 1. Defaults to 8.

    Raises:
        ValueError: If ``reduction`` is smaller than 1.
    """

    def __init__(self, in_channels, reduction=8):
        super(AttentionBlock, self).__init__()
        if reduction < 1:
            raise ValueError(f"reduction must be >= 1, got {reduction}")
        # Clamp to one channel: with fewer than `reduction` input channels the
        # integer division would otherwise produce a zero-width bottleneck.
        hidden_channels = max(1, in_channels // reduction)
        self.attention = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=1),  # Reduce dimensionality
            nn.ReLU(True),
            nn.Conv2d(hidden_channels, in_channels, kernel_size=1),  # Restore to original channels
            nn.Sigmoid()  # Output attention map with values between 0 and 1
        )

    def forward(self, x):
        """Return ``x`` multiplied element-wise by its attention map (same shape as ``x``)."""
        attention_weights = self.attention(x)
        return x * attention_weights  # Apply attention weights to the input feature map


# Define the U-Net style autoencoder with attention
class UNetStyleAutoencoderWithAttention(nn.Module):
    """Two-stage convolutional autoencoder with attention gates and one skip connection.

    Data flow for a ``(batch, 1, H, W)`` input (e.g. 512x512):

    * ``encoder1`` -> ``attention1``: 64 channels at half resolution.
    * ``encoder2`` -> ``attention2``: 128 channels at quarter resolution.
    * ``decoder1``: back to 64 channels at half resolution.
    * The attended ``encoder1`` output is concatenated on the channel axis
      (64 + 64 = 128 channels) and ``decoder2`` maps it to a single channel at
      the input resolution. The final ReLU keeps the output non-negative.
    """

    def __init__(self):
        super().__init__()

        # ---- Encoder: every stage halves the spatial size (kernel 4, stride 2, padding 1)
        self.encoder1 = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.Dropout(p=0.3),
        )
        self.encoder2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.3),
        )

        # ---- One attention gate per encoder stage
        self.attention1 = AttentionBlock(64)
        self.attention2 = AttentionBlock(128)

        # ---- Decoder: every stage doubles the spatial size
        self.decoder1 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.Dropout(p=0.3),
        )
        # Takes 128 channels: 64 from decoder1 plus 64 from the skip connection.
        self.decoder2 = nn.Sequential(
            nn.ConvTranspose2d(128, 1, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Encode, decode and return a tensor with the same shape as ``x``."""
        # Encoder path; the attended stage-1 features double as the skip tensor.
        skip_features = self.attention1(self.encoder1(x))
        bottleneck = self.attention2(self.encoder2(skip_features))

        # Decoder path: upsample, merge with the skip tensor on the channel axis,
        # then upsample back to the input resolution.
        upsampled = self.decoder1(bottleneck)
        merged = torch.cat([upsampled, skip_features], dim=1)
        return self.decoder2(merged)


# Instantiate the model and move it to the GPU if available
model = UNetStyleAutoencoderWithAttention().to(device)

# Example input tensor (batch size of 1, 1 channel, 512x512 image)
input_tensor = torch.randn(1, 1, 512, 512).to(device)  # Move input to GPU

# Forward pass through the model. This is only a shape check, so autograd
# bookkeeping is disabled to avoid holding activations for a backward pass.
with torch.no_grad():
    output = model(input_tensor)

# Actually enforce (not just display) that the autoencoder preserves the input shape.
assert output.shape == input_tensor.shape, (
    f"expected output shape {tuple(input_tensor.shape)}, got {tuple(output.shape)}"
)
print(output.shape)


torch.Size([1, 1, 512, 512])


Load dataset and merge

In [None]:
import os
import numpy as np

# Define the function to load datasets
def load_all_datasets(base_path, max_datasets=10):
    """Load paired high-/low-dose arrays from ``*_dataset.npz`` files and merge them.

    Files are visited in sorted filename order so the subset selected by
    ``max_datasets`` is reproducible (``os.listdir`` gives no ordering guarantee).
    A file contributes only when BOTH arrays were read and have the same number
    of entries along axis 0, so the two returned arrays always stay aligned.
    Files that fail to load are reported and skipped.

    Args:
        base_path (str): Directory containing the ``*_dataset.npz`` archives.
        max_datasets (int): Stop after this many archives were loaded successfully.

    Returns:
        tuple[np.ndarray, np.ndarray]: ``(high_dose_imgs, low_dose_imgs)``, each the
        concatenation along axis 0 of the per-file arrays.

    Raises:
        ValueError: If no dataset could be loaded from ``base_path``.
    """
    high_dose_imgs = []
    low_dose_imgs = []

    dataset_files = sorted(
        name for name in os.listdir(base_path) if name.endswith('_dataset.npz')
    )

    for file_name in dataset_files:
        file_path = os.path.join(base_path, file_name)
        print(f"Loading dataset: {file_name}")

        try:
            with np.load(file_path) as data:
                # Read both arrays before touching the output lists: a failure on
                # the second key must not leave an unpaired entry behind.
                high = data['high_dose_img']
                low = data['low_dose_img']
            if high.shape[:1] != low.shape[:1]:
                raise ValueError(
                    f"image count mismatch: high_dose_img {high.shape} vs low_dose_img {low.shape}"
                )
        except Exception as e:
            # Best effort: report the broken file and continue with the others.
            print(f"Error loading dataset {file_name}: {e}")
            continue

        high_dose_imgs.append(high)
        low_dose_imgs.append(low)

        # Stop loading if we've reached the maximum number of datasets
        if len(high_dose_imgs) >= max_datasets:
            break

    if not high_dose_imgs:
        raise ValueError(f"No '*_dataset.npz' file could be loaded from {base_path!r}")

    # Concatenate all images into a global array
    high_dose_imgs = np.concatenate(high_dose_imgs, axis=0)
    low_dose_imgs = np.concatenate(low_dose_imgs, axis=0)

    return high_dose_imgs, low_dose_imgs

Prepare data for training

In [None]:
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader

# Function to prepare data
def prepare_data(high_dose_imgs, low_dose_imgs, test_size=0.2, batch_size=16, random_state=42):
    """Normalize paired image arrays, split them, and wrap them in DataLoaders.

    Both arrays are converted to float32 and divided by 255, given a channel
    axis at position 1, and split with ``train_test_split`` using the same
    shuffling for both so pairs stay aligned. The caller's arrays are not modified.

    NOTE(review): dividing by 255 presumes 8-bit intensities; confirm against
    how the ``.npz`` files were produced.

    Args:
        high_dose_imgs (np.ndarray): High-dose images, one per entry along axis 0.
        low_dose_imgs (np.ndarray): Matching low-dose images.
        test_size (float): Passed to ``train_test_split``.
        batch_size (int): Batch size of both DataLoaders.
        random_state (int): Seed passed to ``train_test_split``. Defaults to 42.

    Returns:
        tuple[DataLoader, DataLoader]: ``(train_loader, test_loader)``. Each batch
        is a ``(high_dose, low_dose)`` pair. Only the training loader shuffles.
    """
    # astype() returns a fresh float32 copy, so the in-place division cannot touch
    # the caller's data and avoids a second full-size temporary.
    high_scaled = high_dose_imgs.astype(np.float32)
    high_scaled /= 255.0
    low_scaled = low_dose_imgs.astype(np.float32)
    low_scaled /= 255.0

    # Convert to PyTorch tensors (sharing memory with the scaled arrays) and add
    # the channel dimension: (N, ...) -> (N, 1, ...)
    high_dose_tensors = torch.from_numpy(high_scaled).unsqueeze(1)
    low_dose_tensors = torch.from_numpy(low_scaled).unsqueeze(1)

    # Split data into training and testing sets
    train_high, test_high, train_low, test_low = train_test_split(
        high_dose_tensors, low_dose_tensors, test_size=test_size, random_state=random_state
    )

    # Create PyTorch datasets
    train_dataset = TensorDataset(train_high, train_low)
    test_dataset = TensorDataset(test_high, test_low)

    # Pinned host memory speeds up host-to-GPU copies; only useful when CUDA exists.
    use_pinned_memory = torch.cuda.is_available()
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=use_pinned_memory)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, pin_memory=use_pinned_memory)

    return train_loader, test_loader


Define Loss Function

In [None]:
import torch
import torch.nn.functional as F

def ssim_loss(predicted, target, window_size=11, size_average=True):
    """Structural-dissimilarity loss ``(1 - SSIM) / 2`` clamped to ``[0, 1]``.

    Local statistics are taken with a uniform (average-pooling) window of side
    ``window_size``, stride 1 and no padding, so the SSIM map is smaller than
    the input by ``window_size - 1`` in each spatial dimension.

    NOTE(review): the stabilizers 0.01**2 and 0.03**2 correspond to a unit
    dynamic range; presumably inputs are scaled to [0, 1] - confirm with callers.

    Args:
        predicted (torch.Tensor): Predicted images, ``(batch, channels, H, W)``.
        target (torch.Tensor): Reference images of the same shape.
        window_size (int): Side length of the pooling window.
        size_average (bool): If True return a scalar built from the mean SSIM,
            otherwise return the per-position loss map.

    Returns:
        torch.Tensor: Scalar loss, or the loss map when ``size_average`` is False.
    """
    c1, c2 = 0.01 ** 2, 0.03 ** 2

    def local_mean(tensor):
        # Box filter: mean over every window_size x window_size patch.
        return F.avg_pool2d(tensor, window_size, 1)

    # First-order statistics
    mean_pred = local_mean(predicted)
    mean_tgt = local_mean(target)
    mean_pred_sq = mean_pred.pow(2)
    mean_tgt_sq = mean_tgt.pow(2)
    mean_product = mean_pred * mean_tgt

    # Second-order statistics: E[x*y] - E[x]E[y]
    var_pred = local_mean(predicted * predicted) - mean_pred_sq
    var_tgt = local_mean(target * target) - mean_tgt_sq
    covariance = local_mean(predicted * target) - mean_product

    numerator = (2 * mean_product + c1) * (2 * covariance + c2)
    denominator = (mean_pred_sq + mean_tgt_sq + c1) * (var_pred + var_tgt + c2)
    ssim_map = numerator / denominator

    similarity = ssim_map.mean() if size_average else ssim_map
    return torch.clamp((1 - similarity) / 2, 0, 1)

def psnr(predicted, target, max_pixel_value=1.0):
    """
    Peak Signal-to-Noise Ratio in dB: ``10 * log10(max_pixel_value**2 / MSE)``.

    Args:
    predicted (torch.Tensor): The predicted (reconstructed) image(s).
    target (torch.Tensor): The reference image(s), same shape as ``predicted``.
    max_pixel_value (float): Largest possible pixel value (1.0 for normalized images).

    Returns:
    torch.Tensor: Scalar tensor holding the PSNR. The MSE is taken over all
    elements, so a whole batch yields one value.
    """
    mean_squared_error = F.mse_loss(predicted, target)
    peak_power = max_pixel_value ** 2
    return 10 * torch.log10(peak_power / mean_squared_error)

# Standalone sanity check of psnr() on two independent random tensors
# (not part of any training loop; the printed value only shows the function runs).
predicted = torch.rand(1, 1, 256, 256)  # Random stand-in for a reconstructed image
target = torch.rand(1, 1, 256, 256)     # Random stand-in for a real low-dose image

psnr_value = psnr(predicted, target)
print(f"PSNR: {psnr_value:.2f} dB")

class CombinedLoss(nn.Module):
    """Weighted sum of the SSIM loss and the mean absolute error (MAE).

    ``loss = alpha * ssim_loss + (1 - alpha) * MAE``

    Args:
        alpha (float): Weight of the SSIM term; the MAE term gets ``1 - alpha``.
    """

    def __init__(self, alpha=0.84):
        super(CombinedLoss, self).__init__()
        self.alpha = alpha  # Weight between SSIM and MAE loss

    def forward(self, predicted, target):
        """Return the scalar combined loss for a batch of predictions."""
        # SSIM Loss
        ssim_l = ssim_loss(predicted, target)

        # MAE Loss. This term is named and documented as MAE, so it uses the L1
        # loss (as CombinedLossWithIntensity does) rather than the squared error.
        mae_l = F.l1_loss(predicted, target)

        # Combined Loss
        loss = self.alpha * ssim_l + (1 - self.alpha) * mae_l
        return loss


class CombinedLossWithIntensity(nn.Module):
    """SSIM + MAE loss with an extra penalty on the global mean-intensity gap.

    ``loss = alpha * ssim_loss + (1 - alpha) * MAE + beta * |mean(pred) - mean(target)|``

    Args:
        alpha (float): Weight of the SSIM term; the MAE term gets ``1 - alpha``.
        beta (float): Weight of the mean-intensity term.
    """

    def __init__(self, alpha=0.84, beta=0.1):
        super().__init__()
        self.alpha = alpha
        self.beta = beta

    def forward(self, predicted, target):
        """Return the scalar combined loss for a batch of predictions."""
        # Structural term
        structural_term = self.alpha * ssim_loss(predicted, target)

        # Pixel-wise mean absolute error
        pixel_term = (1 - self.alpha) * F.l1_loss(predicted, target)

        # Difference between the overall brightness of prediction and target
        brightness_gap = torch.abs(predicted.mean() - target.mean())
        intensity_term = self.beta * brightness_gap

        return structural_term + pixel_term + intensity_term


PSNR: 7.77 dB


Visualize the results

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from torch.utils.data import DataLoader

def calculate_psnr(img1, img2):
    """Return the PSNR between ``img1`` and ``img2``.

    The data range handed to the metric is the value span (max - min) of ``img2``.
    """
    value_span = img2.max() - img2.min()
    return peak_signal_noise_ratio(img1, img2, data_range=value_span)

def calculate_ssim(img1, img2):
    """Return the SSIM between ``img1`` and ``img2``.

    The data range handed to the metric is the value span (max - min) of ``img2``.
    """
    value_span = img2.max() - img2.min()
    return structural_similarity(img1, img2, data_range=value_span)

def evaluate_metrics_on_dataloader(model, dataloader):
    """Average per-image PSNR and SSIM of ``model`` over every sample in ``dataloader``.

    The model is switched to eval mode and run as ``model(high_dose)`` with
    gradients disabled. Each reconstructed image is scored individually against
    its low-dose counterpart, and the totals are divided by the number of images.
    Scoring a squeezed multi-image batch as a single array would treat it as one
    3-D volume and would weight a short final batch like a full one.

    Args:
        model (torch.nn.Module): Network to evaluate.
        dataloader: Iterable yielding ``(high_dose, low_dose)`` tensor batches
            whose first axis indexes the images.

    Returns:
        tuple[float, float]: ``(avg_psnr, avg_ssim)``.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.eval()

    psnr_total = 0.0
    ssim_total = 0.0
    num_samples = 0

    with torch.no_grad():
        for high_dose, low_dose in dataloader:
            high_dose = high_dose.to(device)
            low_dose = low_dose.to(device)
            output = model(high_dose)

            # Move the whole batch to host memory once, then score image by image.
            low_dose_batch = low_dose.cpu().numpy()
            reconstructed_batch = output.cpu().numpy()

            for low_dose_img, reconstructed_img in zip(low_dose_batch, reconstructed_batch):
                # Drop singleton axes (e.g. the channel axis) to get a plain image.
                low_dose_np = low_dose_img.squeeze()
                reconstructed_np = reconstructed_img.squeeze()

                psnr_total += calculate_psnr(low_dose_np, reconstructed_np)
                ssim_total += calculate_ssim(low_dose_np, reconstructed_np)
                num_samples += 1

    if num_samples == 0:
        raise ValueError("dataloader yielded no samples; cannot average PSNR/SSIM")

    # Calculate average PSNR and SSIM
    avg_psnr = psnr_total / num_samples
    avg_ssim = ssim_total / num_samples

    return avg_psnr, avg_ssim

from skimage.exposure import match_histograms

def apply_histogram_matching(reconstructed_image, reference_image):
    """
    Remap the intensities of ``reconstructed_image`` so that its histogram follows
    that of ``reference_image``, and return the remapped image.

    Both images are handled as single-channel data (``channel_axis=None``).
    """
    return match_histograms(reconstructed_image, reference_image, channel_axis=None)

def visualize_results_with_histograms(model, test_loader, save_path='output_figure.png'):
    """Report dataset-level metrics and plot one example with histogram matching.

    Steps:
      1. Average PSNR/SSIM over ``test_loader`` via ``evaluate_metrics_on_dataloader``.
      2. Run the model on the first batch and keep the first image of that batch.
      3. Histogram-match that reconstruction to its low-dose reference and score
         the matched image directly against the same reference.
      4. Plot high-dose input, low-dose reference, reconstruction and matched
         reconstruction side by side, save the figure to ``save_path`` and show it.

    Args:
        model (torch.nn.Module): Network called as ``model(high_dose)``.
        test_loader: Iterable yielding ``(high_dose, low_dose)`` tensor batches.
        save_path (str): Destination of the PNG figure.
    """
    # Evaluate metrics on the entire dataset (before histogram matching)
    avg_psnr, avg_ssim = evaluate_metrics_on_dataloader(model, test_loader)
    print(f'Average PSNR: {avg_psnr:.2f} dB, Average SSIM: {avg_ssim:.4f}')

    # Take the first batch for visualization
    model.eval()

    with torch.no_grad():
        high_dose, low_dose = next(iter(test_loader))
        high_dose = high_dose.to(device)
        low_dose = low_dose.to(device)
        output = model(high_dose)

    # Keep only the first image of the batch. Squeezing a whole multi-image batch
    # would give an (N, H, W) array, which imshow cannot draw as a grayscale image.
    high_dose_np = high_dose[0].squeeze().cpu().numpy()
    low_dose_np = low_dose[0].squeeze().cpu().numpy()
    reconstructed_np = output[0].squeeze().cpu().numpy()

    # Apply histogram matching to the reconstructed image
    matched_reconstructed_np = apply_histogram_matching(reconstructed_np, low_dose_np)
    # Keep the same dtype as the reference so both metrics compare like with like.
    matched_reconstructed_np = np.asarray(matched_reconstructed_np, dtype=low_dose_np.dtype)

    # Score the matched image itself against the low-dose reference. It must not
    # be passed through the model again: it is already a model output.
    matched_psnr = calculate_psnr(low_dose_np, matched_reconstructed_np)
    matched_ssim = calculate_ssim(low_dose_np, matched_reconstructed_np)
    print(f'After Histogram Matching - PSNR: {matched_psnr:.2f} dB, SSIM: {matched_ssim:.4f}')

    # Plot original high-dose, low-dose, reconstructed images, and the matched image
    fig, axes = plt.subplots(1, 4, figsize=(18, 6))  # Single row for images

    # Display images
    axes[0].imshow(high_dose_np, cmap='gray')
    axes[0].set_title('Original High-Dose Image')
    axes[0].axis('off')

    axes[1].imshow(low_dose_np, cmap='gray')
    axes[1].set_title('Original Low-Dose Image')
    axes[1].axis('off')

    # The title shows the dataset-wide averages computed above.
    axes[2].imshow(reconstructed_np, cmap='gray')
    axes[2].set_title(f'Reconstructed Low-Dose Image\nPSNR: {avg_psnr:.2f} dB, SSIM: {avg_ssim:.4f}')
    axes[2].axis('off')

    axes[3].imshow(matched_reconstructed_np, cmap='gray')
    axes[3].set_title(f'Histogram-Matched Image\nPSNR: {matched_psnr:.2f} dB, SSIM: {matched_ssim:.4f}')
    axes[3].axis('off')

    plt.tight_layout()

    # Save figure as PNG
    plt.savefig(save_path, format='png')

    # Show the figure
    plt.show()


Plot one randomly selected example

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from skimage.metrics import peak_signal_noise_ratio as psnr_metric, structural_similarity as ssim_metric
from torch.utils.data import DataLoader
from skimage.exposure import match_histograms
import random

def calculate_psnr(img1, img2):
    """Return the PSNR between ``img1`` and ``img2``.

    The data range handed to the metric is the value span (max - min) of ``img2``.
    """
    value_span = img2.max() - img2.min()
    return psnr_metric(img1, img2, data_range=value_span)

def calculate_ssim(img1, img2):
    """Return the SSIM between ``img1`` and ``img2``.

    The data range handed to the metric is the value span (max - min) of ``img2``.
    """
    value_span = img2.max() - img2.min()
    return ssim_metric(img1, img2, data_range=value_span)

def apply_histogram_matching(reconstructed_image, reference_image):
    """
    Remap the intensities of ``reconstructed_image`` so that its histogram follows
    that of ``reference_image``, and return the remapped image.

    Both images are handled as single-channel data (``channel_axis=None``).
    """
    return match_histograms(reconstructed_image, reference_image, channel_axis=None)

def visualize_single_pair_with_histograms(model, test_loader, save_path='output_figure_single.png'):
    """Plot one random example from the first batch of ``test_loader``.

    The model is run as ``model(high_dose)``, the same direction used by the
    training loop and by ``evaluate_metrics_on_dataloader``. One image of the
    batch is picked at random, and PSNR/SSIM of its reconstruction against the
    low-dose reference are printed and shown in the figure title. Despite the
    function name, no histogram matching is performed here.

    Args:
        model (torch.nn.Module): Network called as ``model(high_dose)``.
        test_loader: Iterable yielding ``(high_dose, low_dose)`` tensor batches.
        save_path (str): Destination of the PNG figure.
    """
    model.eval()

    with torch.no_grad():
        # Only the first batch is used for visualization.
        high_dose, low_dose = next(iter(test_loader))
        high_dose = high_dose.to(device)
        low_dose = low_dose.to(device)
        # The network maps high-dose input to a low-dose estimate, so it must be
        # fed the high-dose batch; the low-dose batch is only the reference.
        output = model(high_dose)

    # Randomly select an index from the current batch
    random_idx = random.randint(0, high_dose.size(0) - 1)

    # Extract the selected image from each tensor
    high_dose_np = high_dose[random_idx].squeeze().cpu().numpy()
    low_dose_np = low_dose[random_idx].squeeze().cpu().numpy()
    reconstructed_np = output[random_idx].squeeze().cpu().numpy()

    # Calculate PSNR and SSIM for the reconstructed image
    psnr_value = calculate_psnr(low_dose_np, reconstructed_np)
    ssim_value = calculate_ssim(low_dose_np, reconstructed_np)

    print(f'Before Histogram Matching - PSNR: {psnr_value:.2f} dB, SSIM: {ssim_value:.4f}')

    # Plot high-dose input, low-dose reference and reconstruction in one row
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    # Display images
    axes[0].imshow(high_dose_np, cmap='gray')
    axes[0].set_title('Original High-Dose Image')
    axes[0].axis('off')

    axes[1].imshow(low_dose_np, cmap='gray')
    axes[1].set_title('Original Low-Dose Image')
    axes[1].axis('off')

    axes[2].imshow(reconstructed_np, cmap='gray')
    axes[2].set_title(f'Reconstructed Low-Dose Image\nPSNR: {psnr_value:.2f} dB, SSIM: {ssim_value:.4f}')
    axes[2].axis('off')

    plt.tight_layout()

    # Save figure as PNG
    plt.savefig(save_path, format='png')

    # Show the figure
    plt.show()


Main

In [None]:
# Entry point: load the paired dose datasets from Google Drive
# (the path assumes Drive is mounted at /content/drive).
if __name__ == "__main__":
    base_path = '/content/drive/My Drive/Thesis Master/dataset' # Path where the .npz files are stored

    # Load and merge at most 49 dataset files into two arrays
    high_dose_imgs, low_dose_imgs = load_all_datasets(base_path, max_datasets=49)

    # Check the shape of the loaded data
    print(f"High Dose Images Shape: {high_dose_imgs.shape}")
    print(f"Low Dose Images Shape: {low_dose_imgs.shape}")

    # The train/test DataLoaders are built from these arrays by prepare_data() in the next cell.

Loading dataset: C002_dataset.npz
Loading dataset: C004_dataset.npz
Loading dataset: C012_dataset.npz
Loading dataset: C016_dataset.npz
Loading dataset: C021_dataset.npz
Loading dataset: C027_dataset.npz
Loading dataset: C030_dataset.npz
Loading dataset: C050_dataset.npz
Loading dataset: C052_dataset.npz
Loading dataset: C067_dataset.npz
Loading dataset: C077_dataset.npz
Loading dataset: C081_dataset.npz
Loading dataset: C095_dataset.npz
Loading dataset: C099_dataset.npz
Loading dataset: C107_dataset.npz
Loading dataset: C111_dataset.npz
Loading dataset: C120_dataset.npz
Loading dataset: C121_dataset.npz
Loading dataset: C124_dataset.npz
Loading dataset: C128_dataset.npz
Loading dataset: C130_dataset.npz
Loading dataset: C135_dataset.npz
Loading dataset: C158_dataset.npz
Loading dataset: C160_dataset.npz
Loading dataset: C162_dataset.npz
Loading dataset: C166_dataset.npz
Loading dataset: C170_dataset.npz
Loading dataset: C179_dataset.npz
Loading dataset: C190_dataset.npz
Loading datase

In [None]:
# Normalize the loaded arrays, split them with prepare_data's defaults (test_size=0.2,
# batch_size=16) and wrap them in the train/test DataLoaders used by the cells below.
train_loader, test_loader = prepare_data(high_dose_imgs, low_dose_imgs)

Training

In [None]:
import torch
import torch.optim as optim
import pytorch_ssim

# Initialize a fresh model, the loss function, and the optimizer
model = UNetStyleAutoencoderWithAttention().to(device)  # Move model to the selected device
# criterion = nn.L1Loss()  # If using L1 loss as an alternative
loss_fn = CombinedLossWithIntensity(alpha=0.7, beta=0.1)  # SSIM weight 0.7, MAE weight 0.3, intensity weight 0.1
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

# Training loop
num_epochs = 4

for epoch in range(num_epochs):
    model.train()  # Enable dropout and batch-norm statistics updates
    running_loss = 0.0

    for high_dose, low_dose in train_loader:
        # Move the data to the selected device
        high_dose = high_dose.to(device)
        low_dose = low_dose.to(device)

        optimizer.zero_grad()  # Clear gradients

        # Forward pass: the high-dose image is the input ...
        outputs = model(high_dose)

        # ... and the low-dose image is the target of the loss
        loss = loss_fn(outputs, low_dose)

        # Backward pass and optimization
        loss.backward()

        # Gradient Clipping to avoid exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()  # Update model parameters

        running_loss += loss.item()  # Accumulate the per-batch loss

    # Print the average per-batch loss for the epoch
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / len(train_loader)}')



KeyboardInterrupt: 

In [None]:
# Plot one example from the test set with the trained model. The cell that defines
# visualize_single_pair_with_histograms must have been executed first; the recorded
# NameError below came from running this cell without it.
visualize_single_pair_with_histograms(model, test_loader)

NameError: name 'visualize_single_pair_with_histograms' is not defined

Install the pytorch_ssim package

In [None]:
! pip install pytorch_ssim

Collecting pytorch_ssim
  Downloading pytorch_ssim-0.1.tar.gz (1.4 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: pytorch_ssim
  Building wheel for pytorch_ssim (setup.py) ... [?25l[?25hdone
  Created wheel for pytorch_ssim: filename=pytorch_ssim-0.1-py3-none-any.whl size=2007 sha256=e4f8faef4b55e07f7e0d7e99c6aeebb3a2de01a7b8bff3b328b39e412b6dcda6
  Stored in directory: /root/.cache/pip/wheels/2e/0c/10/4a3f91bd610b23196f1e28f8af80b3ec86786b50f3e86dc21e
Successfully built pytorch_ssim
Installing collected packages: pytorch_ssim
Successfully installed pytorch_ssim-0.1
