In [5]:
import os
import torch
import numpy as np
from PIL import Image
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from scipy.ndimage import variance
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn as nn

# Define the Despeckling Generator network
class DespecklingGenerator(nn.Module):
    def __init__(self):
        super(DespecklingGenerator, self).__init__()
        # Simple CNN architecture with three convolutional layers and ReLU activations
        self.model = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x):
        return self.model(x)

# Set paths for input, expected, and output directories
input_dir = 'dataset/test/sar_tif'
expected_dir = 'dataset/test/gray_tif'
output_dir = 'output'
os.makedirs(output_dir, exist_ok=True)

# Load the saved model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
GD = DespecklingGenerator().to(device)
checkpoint = torch.load('models/best_despeckling_model_1.pth', map_location=device)
GD.load_state_dict(checkpoint['GD'])
GD.eval()

# Define transformations for SAR and expected images
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Define Dataset class for inference
class InferenceDataset(Dataset):
    def __init__(self, input_dir, expected_dir, transform=None):
        self.input_dir = input_dir
        self.expected_dir = expected_dir
        self.transform = transform
        self.image_names = [f for f in os.listdir(input_dir) if f.endswith('.tif')]

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        input_file = self.image_names[idx]
        expected_file = input_file.replace('_s1', '_s2')

        input_image = Image.open(os.path.join(self.input_dir, input_file)).convert("L")
        expected_image = Image.open(os.path.join(self.expected_dir, expected_file)).convert("L")

        if self.transform:
            input_image = self.transform(input_image)
            expected_image = self.transform(expected_image)

        return input_file, input_image, expected_image

# Load the inference dataset
inference_dataset = InferenceDataset(input_dir=input_dir, expected_dir=expected_dir, transform=transform)
inference_loader = DataLoader(inference_dataset, batch_size=1, shuffle=False)

# Function to calculate SNR
def snr(signal, noise):
    signal_power = np.mean(np.square(signal))
    noise_power = np.mean(np.square(noise))
    return 10 * np.log10(signal_power / noise_power) if noise_power != 0 else float('inf')

# Inference and metrics calculation
psnr_values, ssim_values, snr_values, nvl_values = [], [], [], []

with torch.no_grad():
    for input_file, sar_image, expected_image in inference_loader:
        sar_image = sar_image.to(device)
        despeckled = GD(sar_image)

        # Convert outputs and ground truth to numpy arrays
        despeckled_np = despeckled.squeeze().cpu().numpy()
        expected_np = expected_image.squeeze().cpu().numpy()

        # Save despeckled output
        output_file = os.path.join(output_dir, input_file[0])
        despeckled_img = Image.fromarray((despeckled_np * 255).astype(np.uint8))
        despeckled_img.save(output_file)

        # Calculate metrics
        psnr_val = psnr(despeckled_np, expected_np, data_range=expected_np.max() - expected_np.min())
        ssim_val = ssim(despeckled_np, expected_np, data_range=expected_np.max() - expected_np.min())
        snr_val = snr(expected_np, despeckled_np - expected_np)
        nvl_val = variance(despeckled_np)

        # Append metrics to lists
        psnr_values.append(psnr_val)
        ssim_values.append(ssim_val)
        snr_values.append(snr_val)
        nvl_values.append(nvl_val)

# Calculate average values of metrics
avg_psnr = np.mean(psnr_values)
avg_ssim = np.mean(ssim_values)
avg_snr = np.mean(snr_values)
avg_nvl = np.mean(nvl_values)

# Print the average metrics
print(f"Average PSNR: {avg_psnr:.4f}")
print(f"Average SSIM: {avg_ssim:.4f}")
print(f"Average SNR: {avg_snr:.4f}")
print(f"Average NVL: {avg_nvl:.4f}")


  checkpoint = torch.load('models/best_despeckling_model_1.pth', map_location=device)


Average PSNR: 13.5034
Average SSIM: 0.0717
Average SNR: 0.5446
Average NVL: 0.0095
