# In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, ConcatDataset, Subset
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import math

class Noise2NoiseDataset(Dataset):
    """Synthetic subsidence time series that yields Noise2Noise training triples.

    Every index corresponds to one acquisition time from ``get_times()``.  For
    that time the dataset builds a noise-free line-of-sight (LOS) displacement
    image produced by a single point source located ``D`` pixels below the
    image centre, and returns two independently-noised copies of that image
    together with the clean image.

    Noise is drawn from NumPy's global RNG on every access, so reading the
    same index twice returns different noise realisations.

    Args:
        size: ``(rows, cols)`` of the generated images.
        S_max: Magnitude of the vertical displacement directly above the
            source when the time factor equals 1.
        D: Source depth, expressed in pixels (distances are measured on the
            pixel grid).
        nu, cm, V: Multiplicative constants of the displacement formula
            (presumably Poisson's ratio, compaction coefficient and source
            volume -- TODO confirm against the model reference).
        random_noise_std: Standard deviation of the white Gaussian noise.
        tropospheric_noise_beta: Exponent of the ``k ** (-beta)`` filter that
            shapes the spatially-correlated noise.
        tropospheric_noise_scale: Standard deviation of the correlated noise
            after normalisation.
        total_days: Length of the simulated period.
        interval_days: Spacing between consecutive acquisitions.
        f_t: Optional callable ``f_t(t, times, total_time)`` returning the
            fraction of the final deformation reached at time ``t``.  Any
            non-callable value (the default ``-1``) selects the linear law
            ``t / total_time``.
        orbit_type: ``'ascending'`` or ``'descending'``.

    Raises:
        ValueError: If ``orbit_type`` is not one of the two supported values.
    """

    def __init__(self, size, S_max=5.0, D=50, nu=0.25, cm=1.0, V=1.0,
                 random_noise_std=0.56, tropospheric_noise_beta=1.82, tropospheric_noise_scale=1.0,
                 total_days=1460, interval_days=49, f_t=-1, orbit_type='ascending'):
        self.size = size
        self.S_max = S_max
        self.D = D
        self.nu = nu
        self.cm = cm
        self.V = V
        self.random_noise_std = random_noise_std
        self.tropospheric_noise_beta = tropospheric_noise_beta
        self.tropospheric_noise_scale = tropospheric_noise_scale
        self.total_days = total_days
        self.interval_days = interval_days
        self.f_t = f_t
        self.orbit_type = orbit_type

        # Lazily-filled ``(key, array)`` caches for the per-sample invariants.
        # The key holds the parameters the array depends on, so changing one
        # of those attributes after construction transparently rebuilds it.
        self._tropo_filter_cache = None
        self._los_pattern_cache = None

        self.incidence_angle_deg, self.satellite_azimuth_deg = self._get_orbit_geometry()
        self.times = self.get_times()
        if len(self.times) == 0:
            # get_times() always returns at least one entry; this only guards
            # subclasses that override it.
            print(f"Warning: No time steps generated for total_days={total_days}, interval_days={interval_days}. Defaulting total_time to 1.0.")
            self.total_time = 1.0
        else:
            self.total_time = self.times[-1]

    def __len__(self):
        """Return the number of acquisition times (one sample per time)."""
        return len(self.times)

    def generate_random_noise(self):
        """Return white Gaussian noise of shape ``size``."""
        return np.random.normal(loc=0.0, scale=self.random_noise_std, size=self.size)

    def _tropospheric_filter(self):
        """Return the cached ``k ** (-beta)`` Fourier-domain filter."""
        key = (tuple(self.size), self.tropospheric_noise_beta)
        cached = getattr(self, "_tropo_filter_cache", None)
        if cached is None or cached[0] != key:
            ky = np.fft.fftfreq(self.size[0])
            kx = np.fft.fftfreq(self.size[1])
            kx, ky = np.meshgrid(kx, ky)
            k = np.sqrt(kx**2 + ky**2)
            # Avoid 0 ** (-beta) at the DC bin; the resulting offset is
            # removed again when the field is mean-centred.
            k[0, 0] = 1e-7
            cached = (key, k ** (-self.tropospheric_noise_beta))
            self._tropo_filter_cache = cached
        return cached[1]

    def generate_tropospheric_noise(self):
        """Return spatially-correlated noise with zero mean.

        White noise is filtered in the Fourier domain by ``k ** (-beta)``,
        then standardised and multiplied by ``tropospheric_noise_scale``.

        NOTE(review): the filter multiplies Fourier *amplitudes*, so the power
        spectrum of the result decays as ``k ** (-2 * beta)`` -- confirm this
        is the intended meaning of ``beta``.
        """
        spectrum = np.fft.fft2(np.random.randn(*self.size))
        frac_noise = np.fft.ifft2(spectrum * self._tropospheric_filter()).real
        frac_noise = frac_noise - frac_noise.mean()
        std_val = frac_noise.std()
        if std_val > 1e-6:
            # Skip the division for a (numerically) constant field.
            frac_noise = frac_noise / std_val
        return frac_noise * self.tropospheric_noise_scale

    @staticmethod
    def calculate_los_vector(incidence_angle_deg, satellite_azimuth_deg):
        """Return the ``[east, north, up]`` LOS unit vector.

        The look direction is taken 90 degrees clockwise from the satellite
        heading given by ``satellite_azimuth_deg``.
        """
        incidence_angle_rad = np.deg2rad(incidence_angle_deg)
        look_azimuth_rad = np.deg2rad(satellite_azimuth_deg) + np.pi / 2

        horizontal = np.sin(incidence_angle_rad)
        return np.array([
            horizontal * np.sin(look_azimuth_rad),
            horizontal * np.cos(look_azimuth_rad),
            np.cos(incidence_angle_rad),
        ])

    def _unit_los_pattern(self):
        """Return the cached LOS displacement image for a unit source factor.

        With ``dx``/``dy`` the pixel offsets from the image centre and
        ``q = (dx**2 + dy**2 + D**2) ** 1.5`` the three displacement
        components are ``dx / q``, ``dy / q`` and ``D / q``; they are
        projected onto the LOS vector here once instead of per sample.
        """
        key = (tuple(self.size), self.D, self.incidence_angle_deg, self.satellite_azimuth_deg)
        cached = getattr(self, "_los_pattern_cache", None)
        if cached is None or cached[0] != key:
            y, x = np.indices(self.size)
            dx = x - self.size[1] // 2
            dy = y - self.size[0] // 2
            denom = (dx**2 + dy**2 + self.D**2) ** 1.5
            l_east, l_north, l_up = self.calculate_los_vector(
                self.incidence_angle_deg, self.satellite_azimuth_deg
            )
            pattern = (dx * l_east + dy * l_north + self.D * l_up) / denom
            cached = (key, pattern)
            self._los_pattern_cache = cached
        return cached[1]

    def generate_subsidence(self, delta_P):
        """Return the clean LOS displacement image for source strength ``delta_P``.

        The result is a freshly allocated array; the cached pattern itself is
        never handed out, so callers may modify the returned image freely.
        """
        factor = (-1 / np.pi) * self.cm * (1 - self.nu) * delta_P * self.V
        return factor * self._unit_los_pattern()

    def get_times(self):
        """Return the acquisition times.

        Normally ``interval_days, 2 * interval_days, ...`` up to and including
        ``total_days``.  When the interval is non-positive or longer than the
        period, a single time is returned instead (``total_days``, or ``1.0``
        if that is not positive).
        """
        if self.total_days < self.interval_days or self.interval_days <= 0:
            return np.array([self.total_days if self.total_days > 0 else 1.0])
        return np.arange(self.interval_days, self.total_days + 1, self.interval_days)

    def _get_clean_subsidence_image(self, t):
        """Return the noise-free LOS image at time ``t``.

        The source strength is scaled so that the vertical displacement above
        the source equals ``-S_max`` times the time factor.
        """
        # Use a local fallback instead of overwriting self.total_time: a
        # getter should not change the dataset's state.
        total_time = self.total_time if self.total_time > 0 else 1.0

        if callable(self.f_t) and len(self.times) > 0:
            time_factor = self.f_t(t, self.times, total_time)
        else:
            time_factor = t / total_time

        delta_P_for_S_max = self.S_max * np.pi * self.D**2 / (self.cm * (1 - self.nu) * self.V)
        return self.generate_subsidence(delta_P=delta_P_for_S_max * time_factor)

    def _get_orbit_geometry(self):
        """Return ``(incidence_angle_deg, satellite_azimuth_deg)`` for the orbit."""
        if self.orbit_type == 'ascending':
            return 40, 15
        if self.orbit_type == 'descending':
            return 40, 195
        raise ValueError("Invalid orbit type. Choose 'ascending' or 'descending'.")

    def _add_noise(self, clean_image):
        """Return ``clean_image`` plus fresh white and correlated noise."""
        random_part = self.generate_random_noise()
        tropo_part = self.generate_tropospheric_noise()
        return clean_image + random_part + tropo_part

    @staticmethod
    def _to_tensor(image):
        """Convert a 2-D array to a float32 tensor with a leading channel axis."""
        return torch.from_numpy(image).float().unsqueeze(0)

    def __getitem__(self, idx):
        """Return ``(noisy_1, noisy_2, clean)`` tensors of shape ``(1, rows, cols)``."""
        clean_image = self._get_clean_subsidence_image(self.times[idx])
        noisy_image1 = self._add_noise(clean_image)
        noisy_image2 = self._add_noise(clean_image)
        return (
            self._to_tensor(noisy_image1),
            self._to_tensor(noisy_image2),
            self._to_tensor(clean_image),
        )


def f_linear(t, times_array, total_time_val):
    """Linear deformation law: fraction ``t / total_time_val`` of the final value.

    ``times_array`` is accepted only so the signature matches the other time
    laws; it is not used.  A zero ``total_time_val`` yields ``0``.
    """
    return 0 if total_time_val == 0 else t / total_time_val

def f_log(t, times_array, total_time_val):
    """Logarithmic deformation law, clipped to the range [0, 1].

    Times are shifted so that the first entry of ``times_array`` maps to 1
    (``1.0`` is used as the first time when the array is empty); the result
    is ``log1p(shifted t) / log1p(shifted total)``.

    Special cases:
        * ``total_time_val == 0`` returns ``0``.
        * Non-positive ``t`` or ``total_time_val`` are replaced by ``1e-6``.
        * If the shifted total is at most 1 the law degenerates to a step:
          ``1.0`` once ``t`` has reached ``total_time_val``, else ``0.0``.
    """
    if total_time_val == 0:
        return 0

    tiny = 1e-6
    if t <= 0:
        t = tiny
    if total_time_val <= 0:
        total_time_val = tiny

    series_start = times_array[0] if len(times_array) > 0 else 1.0
    shifted_t = t - series_start + 1
    shifted_total = total_time_val - series_start + 1

    if shifted_total <= 1:
        # No usable time span: behave as a step function.
        reached_end = t >= total_time_val
        return 1.0 if reached_end else 0.0

    numerator = np.log1p(max(0, shifted_t))
    denominator = np.log1p(max(1e-7, shifted_total))
    ratio = numerator / denominator
    return min(max(0, ratio), 1.0)


# --- Simulation parameters shared by every dataset variant ---
IMG_SIZE = (1500, 1500)
S_MAX = 5
D_DEPTH = 50
RANDOM_NOISE_STD = 0.56
TROPO_SCALE = 1.0
TOTAL_DAYS = 365 * 4
INTERVAL_DAYS = 49
# Fixes the train/val/test partition and the DataLoader seeding so that two
# runs of this script work on the same split.
SEED = 42

_COMMON_DATASET_KWARGS = dict(
    size=IMG_SIZE, S_max=S_MAX, D=D_DEPTH, random_noise_std=RANDOM_NOISE_STD,
    tropospheric_noise_scale=TROPO_SCALE, total_days=TOTAL_DAYS, interval_days=INTERVAL_DAYS,
)

# One dataset per (time law, orbit) combination; f_t=-1 selects the linear law.
print("Creating datasets...")
dataset_lin_asc = Noise2NoiseDataset(f_t=-1, orbit_type='ascending', **_COMMON_DATASET_KWARGS)
print(f"Linear Ascending: {len(dataset_lin_asc)} samples")

dataset_log_asc = Noise2NoiseDataset(f_t=f_log, orbit_type='ascending', **_COMMON_DATASET_KWARGS)
print(f"Log Ascending: {len(dataset_log_asc)} samples")

dataset_lin_desc = Noise2NoiseDataset(f_t=-1, orbit_type='descending', **_COMMON_DATASET_KWARGS)
print(f"Linear Descending: {len(dataset_lin_desc)} samples")

dataset_log_desc = Noise2NoiseDataset(f_t=f_log, orbit_type='descending', **_COMMON_DATASET_KWARGS)
print(f"Log Descending: {len(dataset_log_desc)} samples")

combined_dataset = ConcatDataset([
    dataset_lin_asc, dataset_log_asc,
    dataset_lin_desc, dataset_log_desc
])
print(f"Total combined samples: {len(combined_dataset)}")

# --- Reproducible train / validation / test partition ---
total_samples = len(combined_dataset)
# A dedicated generator keeps the split independent of how much of NumPy's
# global random stream was consumed before this point.
indices = np.random.default_rng(SEED).permutation(total_samples).tolist()

train_ratio = 0.7
val_ratio = 0.15
test_ratio = 0.15  # informational: the test set is whatever remains

train_split = int(train_ratio * total_samples)
val_split = int(val_ratio * total_samples)

train_indices = indices[:train_split]
val_indices = indices[train_split : train_split + val_split]
test_indices = indices[train_split + val_split:]

train_dataset = Subset(combined_dataset, train_indices)
val_dataset = Subset(combined_dataset, val_indices)
test_dataset = Subset(combined_dataset, test_indices)

print(f"Train samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")


def _seed_worker(worker_id):
    """Seed NumPy inside a DataLoader worker from that worker's torch seed.

    The dataset draws its noise from NumPy's global RNG; seeding it per
    worker guarantees that no two workers replay the same noise stream.
    """
    np.random.seed(torch.initial_seed() % 2**32)


# --- Data loaders ---
BATCH_SIZE = 16
_loader_generator = torch.Generator()
_loader_generator.manual_seed(SEED)
_LOADER_KWARGS = dict(
    batch_size=BATCH_SIZE,
    num_workers=2,
    # Pinned host memory only helps when batches are copied to a GPU.
    pin_memory=torch.cuda.is_available(),
    worker_init_fn=_seed_worker,
    generator=_loader_generator,
)
train_loader = DataLoader(train_dataset, shuffle=True, **_LOADER_KWARGS)
val_loader = DataLoader(val_dataset, shuffle=False, **_LOADER_KWARGS)
test_loader = DataLoader(test_dataset, shuffle=False, **_LOADER_KWARGS)

class SimpleUNet(nn.Module):
    """Small U-Net: a stack of double-conv encoder stages, a bottleneck, and a
    mirrored decoder with skip connections.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of channels produced by the final 1x1 convolution.
        features: Channel width of each encoder stage, shallowest first.  The
            bottleneck uses twice the last width.  Any sequence is accepted.

    Raises:
        ValueError: If ``features`` is empty.

    Input height/width do not need to be divisible by ``2 ** len(features)``:
    decoder feature maps are resized to the matching skip connection when the
    2x up-convolution does not reproduce its size exactly.
    """

    def __init__(self, in_channels=1, out_channels=1, features=(32, 64, 128)):
        super().__init__()
        # Copy into a list: the default is an immutable tuple (no shared
        # mutable default) and callers may pass any sequence.
        features = list(features)
        if not features:
            raise ValueError("features must contain at least one channel width")

        # Attribute and construction order are kept as-is so that state_dict
        # keys and the weight-initialisation order stay unchanged.
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        for feature in features:
            self.downs.append(self._double_conv(in_channels, feature))
            in_channels = feature

        self.bottleneck = self._double_conv(features[-1], features[-1] * 2)

        # self.ups alternates [up-conv, double-conv, up-conv, double-conv, ...].
        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
            self.ups.append(self._double_conv(feature * 2, feature))

        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    @staticmethod
    def _double_conv(in_c, out_c):
        """Return two (3x3 conv -> batch norm -> ReLU) layers in sequence."""
        return nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_c, out_c, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Map a ``(N, in_channels, H, W)`` batch to ``(N, out_channels, H, W)``."""
        skip_connections = []
        for down_conv in self.downs:
            x = down_conv(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # Walk the decoder from the deepest skip connection to the shallowest.
        for stage, skip_connection in enumerate(reversed(skip_connections)):
            x = self.ups[2 * stage](x)
            if x.shape[2:] != skip_connection.shape[2:]:
                # Odd-sized maps lose a row/column when pooled; resize so the
                # concatenation below lines up.
                x = torch.nn.functional.interpolate(x, size=skip_connection.shape[2:])
            x = self.ups[2 * stage + 1](torch.cat((skip_connection, x), dim=1))

        return self.final_conv(x)


# --- Hyperparameters ---
MODEL_FEATURES = [64, 128, 256, 512]
LEARNING_RATE = 1e-3
NUM_EPOCHS = 100
WEIGHT_DECAY = 1e-5

# --- Device selection: use the GPU whenever one is visible ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# --- Model, loss and optimiser ---
model = SimpleUNet(in_channels=1, out_channels=1, features=MODEL_FEATURES)
model = model.to(DEVICE)

criterion = nn.MSELoss()
optimizer = optim.Adam(
    params=model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)
# Gradient scaling is switched on only together with CUDA mixed precision.
scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE.type == 'cuda'))

def train_one_epoch(loader, model, optimizer, criterion, device, scaler):
    """Run one Noise2Noise training pass and return the mean batch loss.

    For every batch the model denoises the first noisy image and is penalised
    against the second, independently-noised image; the clean image in each
    batch is ignored.

    Args:
        loader: Sized iterable of ``(noisy1, noisy2, clean)`` batches.
        model: Network to train (switched to training mode).
        optimizer: Optimizer updating ``model``'s parameters.
        criterion: Loss callable ``criterion(prediction, target)``.
        device: ``torch.device`` or device string the batches are moved to.
            Mixed precision is enabled when this is a CUDA device.
        scaler: ``GradScaler`` used for the backward pass and optimizer step.

    Returns:
        Mean loss over the batches as a float, or ``nan`` if ``loader``
        produced no batches.
    """
    model.train()
    # Derive autocast from the device actually passed in, not from a
    # module-level global, so the function works for any caller.
    use_amp = torch.device(device).type == 'cuda'

    running_loss = 0.0
    num_batches = 0
    for batch_idx, (noisy1, noisy2, _) in enumerate(loader):
        noisy1 = noisy1.to(device)
        noisy2 = noisy2.to(device)

        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=use_amp):
            denoised_output = model(noisy1)
            loss = criterion(denoised_output, noisy2)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # .item() synchronises with the device; read it once per batch.
        batch_loss = loss.item()
        running_loss += batch_loss
        num_batches += 1

        if batch_idx % 50 == 0:
            print(f"Batch {batch_idx}/{len(loader)}, Loss: {batch_loss:.4f}")

    # An empty loader has no meaningful average; report nan instead of
    # failing with a ZeroDivisionError.
    avg_epoch_loss = running_loss / num_batches if num_batches else float("nan")
    print(f"End of Epoch, Avg Training Loss: {avg_epoch_loss:.4f}")
    return avg_epoch_loss

def validate_one_epoch(loader, model, criterion, device):
    """Evaluate the Noise2Noise loss over ``loader`` without updating the model.

    The loss is computed between the model output for the first noisy image
    and the second noisy image of each batch; the clean image is ignored.

    Args:
        loader: Iterable of ``(noisy1, noisy2, clean)`` batches.
        model: Network to evaluate (switched to evaluation mode).
        criterion: Loss callable ``criterion(prediction, target)``.
        device: ``torch.device`` or device string the batches are moved to.
            Mixed precision is enabled when this is a CUDA device.

    Returns:
        Mean loss over the batches as a float, or ``nan`` if ``loader``
        produced no batches.
    """
    model.eval()
    # Derive autocast from the device actually passed in, not from a
    # module-level global, so the function works for any caller.
    use_amp = torch.device(device).type == 'cuda'

    running_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for noisy1, noisy2, _ in loader:
            noisy1 = noisy1.to(device)
            noisy2 = noisy2.to(device)

            with torch.cuda.amp.autocast(enabled=use_amp):
                denoised_output = model(noisy1)
                loss = criterion(denoised_output, noisy2)
            running_loss += loss.item()
            num_batches += 1

    # An empty loader has no meaningful average; report nan instead of
    # failing with a ZeroDivisionError.
    avg_epoch_loss = running_loss / num_batches if num_batches else float("nan")
    print(f"Validation Avg Loss: {avg_epoch_loss:.4f}")
    return avg_epoch_loss


# --- Training loop: keep the weights with the lowest validation loss ---
print("\nStarting training...")
train_losses, val_losses = [], []
best_val_loss = float('inf')

for epoch in range(NUM_EPOCHS):
    print(f"\n--- Epoch {epoch+1}/{NUM_EPOCHS} ---")
    train_loss = train_one_epoch(train_loader, model, optimizer, criterion, DEVICE, scaler)
    val_loss = validate_one_epoch(val_loader, model, criterion, DEVICE)

    train_losses.append(train_loss)
    val_losses.append(val_loss)

    # Nothing to checkpoint unless this epoch beat the best validation loss.
    if not (val_loss < best_val_loss):
        continue

    best_val_loss = val_loss
    torch.save(model.state_dict(), "best_denoising_model.pth")
    print(f"Saved new best model with val_loss: {best_val_loss:.4f}")

print("\nTraining finished!")

# --- Loss curves, one point per epoch ---
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label="Training Loss")
plt.plot(val_losses, label="Validation Loss")
plt.title("Training and Validation Loss Over Epochs")
plt.xlabel("Epoch")
plt.ylabel("Loss (MSE)")
plt.grid(True)
plt.legend()
plt.show()


def display_results(model, data_loader, device, num_samples=3):
    """Plot noisy input, model output and ground truth for a few samples.

    One figure row is drawn per sample; the sample is the first element of
    each successive batch from ``data_loader``.  The three panels of a row
    share one colour range, and every row gets its own colour bar so the bar
    always matches the range actually used in that row.

    Args:
        model: Trained network; it is put in evaluation mode and moved to
            ``device``.
        data_loader: Iterable of ``(noisy1, noisy2, clean)`` batches.
        device: Device on which inference runs.
        num_samples: Number of rows to draw.  If the loader yields fewer
            batches, the remaining rows are left empty.

    Raises:
        ValueError: If ``num_samples`` is smaller than 1.
    """
    if num_samples < 1:
        raise ValueError("num_samples must be a positive integer")

    model.eval()
    model.to(device)

    # squeeze=False keeps `axes` two-dimensional even for a single row;
    # constrained layout (unlike tight_layout) accounts for colour bars that
    # span several axes.
    fig, axes = plt.subplots(
        num_samples, 3, figsize=(12, 4 * num_samples),
        squeeze=False, constrained_layout=True,
    )
    panel_titles = ("Noisy Input", "Denoised Output", "Ground Truth")

    rows_drawn = 0
    data_iter = iter(data_loader)
    with torch.no_grad():
        for i in range(num_samples):
            try:
                noisy_input, _, ground_truth = next(data_iter)
            except StopIteration:
                print("Not enough samples in the loader to display the requested number.")
                break

            # Only the first sample of the batch is shown, so only that one
            # needs to be moved to the device.
            single_noisy_input = noisy_input[0:1].to(device)
            denoised_output = model(single_noisy_input)

            panels = (
                single_noisy_input[0].cpu().numpy().squeeze(),
                denoised_output[0].cpu().numpy().squeeze(),
                ground_truth[0].cpu().numpy().squeeze(),
            )
            vmin = min(panel.min() for panel in panels)
            vmax = max(panel.max() for panel in panels)

            for ax, title, panel in zip(axes[i], panel_titles, panels):
                im = ax.imshow(panel, cmap='viridis', vmin=vmin, vmax=vmax)
                ax.set_title(f"{title} (Sample {i+1})")
                ax.axis('off')

            fig.colorbar(im, ax=axes[i].tolist(), shrink=0.6, label='LOS Displacement (m)')
            rows_drawn += 1

    # Hide the frames of rows for which the loader had no data.
    for unused_ax in axes[rows_drawn:].ravel():
        unused_ax.axis('off')

    plt.show()

# best_val_loss only leaves its initial +inf value when a checkpoint has been
# written, so that is the condition for reloading the saved weights.
if not best_val_loss == float('inf'):
    print("\nLoading best model for display...")
    model.load_state_dict(
        torch.load("best_denoising_model.pth", map_location=DEVICE)
    )

# Qualitative check on held-out data.
print("\nDisplaying results from Test Set...")
display_results(model, test_loader, DEVICE, num_samples=3)

# NOTE(review): a stray `False` stood here -- probably leftover notebook cell output.
