# 1. Generating InSAR Dataset

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import math
import os
import pandas as pd

In [None]:
class Noise2NoiseDataset(Dataset):
    def __init__(self, size, S_max=5.0, D=50, nu=0.25, cm=1.0, V=1.0,
                 random_noise_std=0.56, tropospheric_noise_beta=1.82, tropospheric_noise_scale=1.0,
                 total_days=1460, interval_days=49, f_t=-1, orbit_type='ascending'):
        self.size = size
        self.S_max = S_max
        self.D = D
        self.nu = nu
        self.cm = cm
        self.V = V
        self.random_noise_std = random_noise_std
        self.tropospheric_noise_beta = tropospheric_noise_beta
        self.tropospheric_noise_scale = tropospheric_noise_scale
        self.total_days = total_days
        self.interval_days = interval_days
        self.f_t = f_t
        self.orbit_type = orbit_type

        self.incidence_angle_deg, self.satellite_azimuth_deg = self._get_orbit_geometry()
        self.times = self.get_times()
        
        if len(self.times) == 0:
            self.total_time = 1.0
        elif len(self.times) == 1:
             self.total_time = self.times[0] if self.times[0] > 0 else 1.0
        else:
            self.total_time = self.times[-1] if self.times[-1] > 0 else 1.0
        if self.total_time <=0: self.total_time = 1.0

    def _load_dem(self, dem_path):
        print(f"Loading DEM data from {dem_path}", flush=True)
        dem_flat = np.fromfile(dem_path, dtype='>4')
        dem_data = dem_flat.reshape((8102, 8102))
        print(f"DEM data shape: {dem_data.shape}", flush=True)
        return dem_data
    
    def _load_baselines(self, baselines_path):
        print(f"Loading baselines data from {baselines_path}", flush=True)
        df = pd.read_csv(baselines_path, delim_whitespace=True, header=None, usecols=[2])
        print(f"Baselines data shape: {df.shape}", flush=True)
        return df[2].values.tolist()

    def __len__(self):
        return len(self.times)

    def generate_random_noise(self):
        return np.random.normal(loc=0.0, scale=self.random_noise_std, size=self.size)

    def generate_tropospheric_noise(self):
        noise = np.fft.fft2(np.random.randn(*self.size))
        ky = np.fft.fftfreq(self.size[0])
        kx = np.fft.fftfreq(self.size[1])
        kx, ky = np.meshgrid(kx, ky)
        k = np.sqrt(kx**2 + ky**2)
        k[0, 0] = 1e-7
        power = k ** (-self.tropospheric_noise_beta)
        frac_noise = np.fft.ifft2(noise * power).real
        std_val = frac_noise.std()
        if std_val > 1e-9:
            frac_noise = (frac_noise - frac_noise.mean()) / std_val
        else:
            frac_noise = frac_noise - frac_noise.mean()
        return frac_noise * self.tropospheric_noise_scale

    @staticmethod
    def calculate_los_vector(incidence_angle_deg, satellite_azimuth_deg):
        incidence_angle_rad = np.deg2rad(incidence_angle_deg)
        satellite_azimuth_rad = np.deg2rad(satellite_azimuth_deg)
        look_azimuth_rad = satellite_azimuth_rad + np.pi/2
        l_east = np.sin(incidence_angle_rad) * np.sin(look_azimuth_rad)
        l_north = np.sin(incidence_angle_rad) * np.cos(look_azimuth_rad)
        l_up = np.cos(incidence_angle_rad)
        return np.array([l_east, l_north, l_up])

    def generate_subsidence(self, delta_P):
        y, x = np.indices(self.size)
        cx, cy = self.size[1] // 2, self.size[0] // 2
        r_sq = (x - cx)**2 + (y - cy)**2
        r = np.sqrt(r_sq)
        factor = (-1 / np.pi) * self.cm * (1 - self.nu) * delta_P * self.V
        denominator_base = r**2 + self.D**2
        denominator_base[denominator_base < 1e-9] = 1e-9
        uz = factor * (self.D / (denominator_base**1.5))
        ur = factor * (r / (denominator_base**1.5))
        azimuth = np.arctan2(y - cy, x - cx)
        ux = ur * np.cos(azimuth)
        uy = ur * np.sin(azimuth)
        los_vector_calc = Noise2NoiseDataset.calculate_los_vector(self.incidence_angle_deg, self.satellite_azimuth_deg)
        simulated_interferogram = (ux * los_vector_calc[0]) + \
                                  (uy * los_vector_calc[1]) + \
                                  (uz * los_vector_calc[2])
        return simulated_interferogram

    def get_times(self):
        if self.total_days < 1 or self.interval_days <=0:
             return np.array([1.0])
        return np.arange(1, self.total_days + 1, self.interval_days)

    def _get_clean_subsidence_image(self, t):
        delta_P_final = -self.S_max * ((np.pi * self.D**2) / (self.cm * (1 - self.nu) * self.V))
        current_time_factor = 0.0
        if callable(self.f_t):
            if len(self.times) > 0:
                current_time_factor = self.f_t(t, self.times, self.total_time)
            else:
                current_time_factor = t / self.total_time if self.total_time > 0 else 0
        else: 
            current_time_factor = t / self.total_time if self.total_time > 0 else 0
        delta_P_current = -delta_P_final * current_time_factor
        return self.generate_subsidence(delta_P=delta_P_current)

    def _get_orbit_geometry(self):
        if self.orbit_type == 'ascending': return 40, 15
        elif self.orbit_type == 'descending': return 40, 195
        else: raise ValueError("Invalid orbit type.")

    def __getitem__(self, idx):
        current_time = self.times[idx]
        clean_image = self._get_clean_subsidence_image(current_time)
        noise1_random = self.generate_random_noise()
        noise1_tropo = self.generate_tropospheric_noise()
        noisy_image1 = clean_image + noise1_random + noise1_tropo
        noise2_random = self.generate_random_noise()
        noise2_tropo = self.generate_tropospheric_noise()
        noisy_image2 = clean_image + noise2_random + noise2_tropo
        clean_image_tensor = torch.from_numpy(clean_image.copy()).float().unsqueeze(0)
        noisy_image1_tensor = torch.from_numpy(noisy_image1.copy()).float().unsqueeze(0)
        noisy_image2_tensor = torch.from_numpy(noisy_image2.copy()).float().unsqueeze(0)
        return noisy_image1_tensor, noisy_image2_tensor, clean_image_tensor

def f_linear(t, times_array, total_time_val):
    if total_time_val == 0: return 0
    return t / total_time_val

def f_log(t, times_array, total_time_val):
    if total_time_val <= 0: return 0
    if t <= 0: t = 1e-6 
    min_time_in_series = times_array[0] if len(times_array)>0 and times_array[0] > 0 else 1.0
    adjusted_t = (t - min_time_in_series) + 1
    adjusted_total_time = (total_time_val - min_time_in_series) + 1
    if adjusted_total_time <= 1:
        return 1.0 if t >= total_time_val else (t/total_time_val if total_time_val > 0 else 0)
    val = np.log1p(max(0, adjusted_t-1)) / np.log1p(max(1e-7, adjusted_total_time-1))
    return min(max(0, val), 1.0)

In [6]:
IMG_SIZE = (1500, 1500)
S_MAX = 5.0
D_DEPTH = 50
NU = 0.25
CM = 1.0
V_PARAM = 1.0
RANDOM_NOISE_STD = 0.56
TROPOSPHERIC_NOISE_BETA = 1.82
TROPOSPHERIC_NOISE_SCALE = 1.0
TOTAL_DAYS = 1460
INTERVAL_DAYS = 49

# IMPORTANT: Change this to your Google Drive path if using Drive
PRECOMPUTED_DATA_ROOT = r"Data"

def generate_and_save_data():
    print(f"Starting data generation. Files will be saved to: {PRECOMPUTED_DATA_ROOT}")
    os.makedirs(os.path.join(PRECOMPUTED_DATA_ROOT, 'noisy1'), exist_ok=True)
    os.makedirs(os.path.join(PRECOMPUTED_DATA_ROOT, 'noisy2'), exist_ok=True)
    os.makedirs(os.path.join(PRECOMPUTED_DATA_ROOT, 'clean'), exist_ok=True)

    manifest_data = []
    global_sample_idx = 0

    dataset_configs = [
        {'name': 'lin_asc_16', 'f_t': -1, 'orbit': 'ascending', 'interval': 16},
        {'name': 'log_asc_16', 'f_t': f_log, 'orbit': 'ascending', 'interval': 16},
        {'name': 'lin_desc_16', 'f_t': -1, 'orbit': 'descending', 'interval': 16},
        {'name': 'log_desc_16', 'f_t': f_log, 'orbit': 'descending', 'interval': 16},
        {'name': 'lin_asc_49', 'f_t': -1, 'orbit': 'ascending', 'interval': 49},
        {'name': 'log_asc_49', 'f_t': f_log, 'orbit': 'ascending', 'interval': 49},
        {'name': 'lin_desc_49', 'f_t': -1, 'orbit': 'descending', 'interval': 49},
        {'name': 'log_desc_49', 'f_t': f_log, 'orbit': 'descending', 'interval': 49}
    ]

    for config in dataset_configs:
        print(f"Generating data for config: {config['name']}")
        gen_dataset = Noise2NoiseDataset(
            size=IMG_SIZE, S_max=S_MAX, D=D_DEPTH, nu=NU, cm=CM, V=V_PARAM,
            random_noise_std=RANDOM_NOISE_STD, tropospheric_noise_beta=TROPOSPHERIC_NOISE_BETA,
            tropospheric_noise_scale=TROPOSPHERIC_NOISE_SCALE, total_days=TOTAL_DAYS,
            interval_days=INTERVAL_DAYS, f_t=config['f_t'], orbit_type=config['orbit']
        )
        
        if len(gen_dataset) == 0:
            print(f"Warning: No samples generated for config {config['name']}. Check total_days and interval_days.")
            continue

        for i in range(len(gen_dataset)):
            try:
                noisy1_tensor, noisy2_tensor, clean_tensor = gen_dataset[i]

                noisy1_fname = os.path.join('noisy1', f'sample_{global_sample_idx:06d}_noisy1.pt')
                noisy2_fname = os.path.join('noisy2', f'sample_{global_sample_idx:06d}_noisy2.pt')
                clean_fname = os.path.join('clean', f'sample_{global_sample_idx:06d}_clean.pt')

                torch.save(noisy1_tensor, os.path.join(PRECOMPUTED_DATA_ROOT, noisy1_fname))
                torch.save(noisy2_tensor, os.path.join(PRECOMPUTED_DATA_ROOT, noisy2_fname))
                torch.save(clean_tensor, os.path.join(PRECOMPUTED_DATA_ROOT, clean_fname))

                manifest_data.append({
                    'id': global_sample_idx,
                    'config_name': config['name'],
                    'original_idx_in_config': i,
                    'time_step': gen_dataset.times[i],
                    'noisy1_path': noisy1_fname,
                    'noisy2_path': noisy2_fname,
                    'clean_path': clean_fname
                })
                global_sample_idx += 1
                if global_sample_idx % 10 == 0: # Print progress
                    print(f"Saved sample {global_sample_idx}...")

            except Exception as e:
                print(f"Error generating/saving sample {global_sample_idx} (original index {i} in config {config['name']}): {e}")
                # Optionally, decide if you want to skip or halt on error
                continue
        print(f"Finished generating for config: {config['name']}. Total samples so far: {global_sample_idx}")


    manifest_df = pd.DataFrame(manifest_data)
    manifest_path = os.path.join(PRECOMPUTED_DATA_ROOT, 'manifest.csv')
    manifest_df.to_csv(manifest_path, index=False)
    print(f"Data generation complete. Total samples: {global_sample_idx}. Manifest saved to {manifest_path}")

generate_and_save_data()

Starting data generation. Files will be saved to: Data
Generating data for config: lin_asc_16
Saved sample 10...
Saved sample 20...
Saved sample 30...
Finished generating for config: lin_asc_16. Total samples so far: 30
Generating data for config: log_asc_16
Saved sample 40...
Saved sample 50...
Saved sample 60...
Finished generating for config: log_asc_16. Total samples so far: 60
Generating data for config: lin_desc_16
Saved sample 70...
Saved sample 80...
Saved sample 90...
Finished generating for config: lin_desc_16. Total samples so far: 90
Generating data for config: log_desc_16
Saved sample 100...
Saved sample 110...
Saved sample 120...
Finished generating for config: log_desc_16. Total samples so far: 120
Generating data for config: lin_asc_49
Saved sample 130...
Saved sample 140...
Saved sample 150...
Finished generating for config: lin_asc_49. Total samples so far: 150
Generating data for config: log_asc_49
Saved sample 160...
Saved sample 170...
Saved sample 180...
Finished 