In [None]:
import os
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from PIL import Image
import numpy as np
from tqdm import tqdm

In [None]:
# Compute device: CUDA when PyTorch can see a GPU, CPU otherwise.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Dataset layout (relative to the working directory): dataset/<split>/{sar_tif,gray_tif}
train_sar_dir, train_gray_dir = 'dataset/train/sar_tif', 'dataset/train/gray_tif'
val_sar_dir, val_gray_dir = 'dataset/val/sar_tif', 'dataset/val/gray_tif'
test_sar_dir, test_gray_dir = 'dataset/test/sar_tif', 'dataset/test/gray_tif'

# Artifact directories; created up front so later cells can write into them.
output_dir, model_dir = 'output', 'models'
os.makedirs(output_dir, exist_ok=True)
os.makedirs(model_dir, exist_ok=True)

In [None]:
# Define Dataset class for paired SAR and grayscale images
class SARDataset(Dataset):
    """Paired SAR / grayscale image dataset backed by two directories.

    Every ``.tif`` file in ``sar_dir`` is one sample. Its grayscale partner is the file in
    ``gray_dir`` whose name is the SAR name with every ``'_s1'`` replaced by ``'_s2'``.
    Both images are loaded as single-channel PIL mode ``"L"`` images and passed through
    the optional transforms.

    Args:
        sar_dir: directory holding the SAR (input) images.
        gray_dir: directory holding the grayscale (target) images.
        transform_sar: optional callable applied to the SAR PIL image.
        transform_gray: optional callable applied to the grayscale PIL image.
    """

    def __init__(self, sar_dir, gray_dir, transform_sar=None, transform_gray=None):
        self.sar_dir = sar_dir
        self.gray_dir = gray_dir
        self.transform_sar = transform_sar
        self.transform_gray = transform_gray
        # os.listdir returns entries in arbitrary, platform-dependent order; sorting makes
        # sample indices (and therefore evaluation order) reproducible.
        self.image_names = sorted(f for f in os.listdir(sar_dir) if f.endswith('.tif'))

    def __len__(self):
        """Number of SAR images found in ``sar_dir``."""
        return len(self.image_names)

    @staticmethod
    def _load_single_channel(path):
        """Load ``path`` as a mode-"L" image and close the underlying file immediately."""
        # NOTE(review): mode "L" is 8-bit; if the TIFs hold 16-bit or float SAR values,
        # this conversion clips/quantizes them — confirm the source format.
        with Image.open(path) as img:
            # convert() reads the pixels and returns an independent image, so it stays
            # valid after the file handle is closed by the with-block.
            return img.convert("L")

    def __getitem__(self, idx):
        """Return the ``(sar_image, gray_image)`` pair at ``idx`` (transformed if configured)."""
        sar_file = self.image_names[idx]
        gray_file = sar_file.replace('_s1', '_s2')

        sar_image = self._load_single_channel(os.path.join(self.sar_dir, sar_file))
        gray_image = self._load_single_channel(os.path.join(self.gray_dir, gray_file))

        if self.transform_sar:
            sar_image = self.transform_sar(sar_image)
        if self.transform_gray:
            gray_image = self.transform_gray(gray_image)

        return sar_image, gray_image

In [None]:
# SAR inputs and grayscale targets share the same preprocessing: convert the PIL image
# to a float tensor (ToTensor), then apply Normalize(mean=0.5, std=0.5), i.e. (x - 0.5) / 0.5.
# For 8-bit images ToTensor yields [0, 1], so the result lies in [-1, 1].
def _make_normalize_transform():
    """Build a new ToTensor -> Normalize(0.5, 0.5) pipeline for single-channel images."""
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ])


transform_sar = _make_normalize_transform()
transform_gray = _make_normalize_transform()

In [None]:
# Despeckling generator network
class DespecklingGenerator(nn.Module):
    """Three-layer 3x3 CNN mapping a 1-channel image to a 1-channel image.

    Every convolution uses stride 1 and padding 1, so the output has the same spatial
    size as the input. Hidden width is 64 with ReLU between layers. There is no output
    activation, so values are unbounded.
    """

    def __init__(self):
        super().__init__()

        def conv3x3(c_in, c_out):
            # Size-preserving 3x3 convolution.
            return nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1)

        # Layer indices (0, 2, 4 for the convs) define the state_dict keys that
        # saved checkpoints rely on.
        self.model = nn.Sequential(
            conv3x3(1, 64), nn.ReLU(inplace=True),
            conv3x3(64, 64), nn.ReLU(inplace=True),
            conv3x3(64, 1),
        )

    def forward(self, x):
        """Apply the convolution stack to ``x``."""
        out = self.model(x)
        return out

In [None]:
# Discriminator network
class Discriminator(nn.Module):
    """Convolutional real/fake discriminator that outputs a map of probabilities.

    Two stride-2 4x4 convolutions halve the spatial size twice. A final stride-1
    unpadded 4x4 convolution and a sigmoid follow. For H and W divisible by 4 the output
    is ``(B, 1, H/4 - 3, W/4 - 3)``, with every value in (0, 1).

    Args:
        in_channels: number of channels in the input images (default 1).
    """

    def __init__(self, in_channels=1):
        super().__init__()
        slope = 0.2  # LeakyReLU negative slope
        # Indices 0/2/3/5 (conv, conv, batchnorm, conv) define the state_dict keys.
        self.model = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),   # 0: /2
            nn.LeakyReLU(slope, inplace=True),                                # 1
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),           # 2: /2
            nn.BatchNorm2d(128),                                              # 3
            nn.LeakyReLU(slope, inplace=True),                                # 4
            nn.Conv2d(128, 1, kernel_size=4, stride=1, padding=0),            # 5: -3
        )

    def forward(self, x):
        """Return sigmoid probabilities that each output location is 'real'."""
        logits = self.model(x)
        return logits.sigmoid()

In [None]:
# Initialize models, optimizers, and loss functions.
# Seed PyTorch before any random draw so weight initialization and DataLoader shuffling
# are reproducible under Restart & Run All (torch.manual_seed seeds all devices).
SEED = 42
torch.manual_seed(SEED)

GD = DespecklingGenerator().to(device)
DD = Discriminator().to(device)
# The discriminator learns 50x more slowly than the generator (2e-6 vs 1e-4).
optimizer_GD = optim.Adam(GD.parameters(), lr=0.0001)
optimizer_DD = optim.Adam(DD.parameters(), lr=0.000002)
adversarial_loss = nn.BCELoss()  # BCE on the discriminator's sigmoid outputs
pixel_loss = nn.L1Loss()  # Pixel-wise L1 loss for the generator
lambda_adv = 0.05  # Weight of the adversarial term in the generator loss

# Load train, validation, and test datasets
train_dataset = SARDataset(sar_dir=train_sar_dir, gray_dir=train_gray_dir, transform_sar=transform_sar, transform_gray=transform_gray)
val_dataset = SARDataset(sar_dir=val_sar_dir, gray_dir=val_gray_dir, transform_sar=transform_sar, transform_gray=transform_gray)
test_dataset = SARDataset(sar_dir=test_sar_dir, gray_dir=test_gray_dir, transform_sar=transform_sar, transform_gray=transform_gray)

# Only the training loader shuffles; its order follows the seeded global torch RNG.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [None]:
# Training loop with validation.
# Each epoch runs one adversarial pass over train_loader (discriminator step, then
# generator step). It then scores the generator on val_loader by mean per-image PSNR
# and checkpoints the generator whenever that score improves.
num_epochs = 5
# The transforms map pixels to [-1, 1] (ToTensor -> [0, 1], then Normalize(0.5, 0.5)),
# so the nominal PSNR data range is 2 for every image, regardless of its contrast.
psnr_data_range = 2.0
best_val_psnr = float('-inf')  # -inf (not 0) so the first epoch always writes a checkpoint
best_model_path = os.path.join(model_dir, "best_despeckling_model_1.pth")


def per_image_psnr(pred, target, data_range=psnr_data_range):
    """Return a list with one PSNR value per image for (B, 1, H, W) pred/target batches.

    Predictions are clamped to the normalized pixel range [-1, 1] first, because the
    generator has no output activation and can overshoot it. Scoring per image, instead
    of over a whole batch as one array, keeps the validation mean independent of the
    batch size and of a smaller final batch.
    """
    pred_np = pred.clamp(-1.0, 1.0).squeeze(1).cpu().numpy()
    target_np = target.squeeze(1).cpu().numpy()
    # skimage's signature is psnr(image_true, image_test, ...).
    return [psnr(t, p, data_range=data_range) for t, p in zip(target_np, pred_np)]


for epoch in range(num_epochs):
    GD.train()
    DD.train()
    epoch_g_loss, epoch_d_loss = 0.0, 0.0
    for sar, gray in tqdm(train_loader, desc=f"Epoch [{epoch+1}/{num_epochs}]"):
        sar, gray = sar.to(device), gray.to(device)
        despeckled = GD(sar)

        # Discriminator step: real targets -> 1, detached generator output -> 0.
        optimizer_DD.zero_grad()
        # A single forward pass on the real batch. It both gives the loss input and fixes
        # the label shape, so BatchNorm stats are updated once per real batch.
        real_pred = DD(gray)
        real_labels = torch.ones_like(real_pred)  # created on real_pred's device
        fake_labels = torch.zeros_like(real_pred)
        real_loss = adversarial_loss(real_pred, real_labels)
        fake_loss = adversarial_loss(DD(despeckled.detach()), fake_labels)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_DD.step()

        # Generator step: L1 to the target, plus a lambda_adv-weighted term pushing the
        # just-updated discriminator to label the generator output as real.
        optimizer_GD.zero_grad()
        g_adv_loss = adversarial_loss(DD(despeckled), real_labels)
        g_pix_loss = pixel_loss(despeckled, gray)
        g_loss = g_pix_loss + lambda_adv * g_adv_loss
        g_loss.backward()
        optimizer_GD.step()

        # Accumulate losses for printing
        epoch_g_loss += g_loss.item()
        epoch_d_loss += d_loss.item()

    # Validation: mean of per-image PSNR over the whole validation set
    GD.eval()
    val_psnr = []
    with torch.no_grad():
        for sar, gray in val_loader:
            val_psnr.extend(per_image_psnr(GD(sar.to(device)), gray))
    avg_val_psnr = float(np.mean(val_psnr))

    # Save best model based on validation PSNR
    if avg_val_psnr > best_val_psnr:
        best_val_psnr = avg_val_psnr
        torch.save({'GD': GD.state_dict()}, best_model_path)
        print(f"Saved best model with PSNR: {avg_val_psnr:.4f}")

    # Print epoch metrics (guard against an empty training loader)
    n_batches = max(len(train_loader), 1)
    print(f"Epoch [{epoch+1}/{num_epochs}] - GD Loss: {epoch_g_loss/n_batches:.4f}, DD Loss: {epoch_d_loss/n_batches:.4f}, Val PSNR: {avg_val_psnr:.4f}")

In [None]:
# Testing and evaluation on the test set.
# Evaluate the checkpoint with the best validation PSNR, not whatever weights the
# generator holds after the final training epoch.
best_model_path = os.path.join(model_dir, "best_despeckling_model_1.pth")
if os.path.exists(best_model_path):
    # torch.load unpickles the file: only load checkpoints this notebook wrote itself.
    checkpoint = torch.load(best_model_path, map_location=device)
    GD.load_state_dict(checkpoint['GD'])
else:
    print(f"No checkpoint at {best_model_path}; evaluating the in-memory generator weights.")

# Pixels are normalized to [-1, 1], so the nominal data range is 2 for every image.
# (Using each target's own max - min would make scores depend on image contrast.)
eval_data_range = 2.0

GD.eval()
psnr_values, ssim_values = [], []
with torch.no_grad():
    for sar, gray in test_loader:
        despeckled = GD(sar.to(device))
        # Clamp to the valid normalized range; the generator has no output activation.
        pred_batch = despeckled.clamp(-1.0, 1.0).squeeze(1).cpu().numpy()
        true_batch = gray.squeeze(1).cpu().numpy()
        # Iterate per image so 2-D SSIM/PSNR stays correct for any test batch size.
        # skimage's signatures take the reference image first: metric(image_true, image_test).
        for true_img, pred_img in zip(true_batch, pred_batch):
            psnr_values.append(psnr(true_img, pred_img, data_range=eval_data_range))
            ssim_values.append(ssim(true_img, pred_img, data_range=eval_data_range))

# Calculate and print average PSNR and SSIM
avg_psnr = float(np.mean(psnr_values))
avg_ssim = float(np.mean(ssim_values))
print(f"Average PSNR on test set: {avg_psnr:.4f}")
print(f"Average SSIM on test set: {avg_ssim:.4f}")

Average PSNR on test set: 12.6496
Average SSIM on test set: 0.0536
