# 01. Advanced Data Simulation for Gravitational Lensing Super-Resolution

**Objective**: Generate high-fidelity simulated gravitational lensing images, as paired low-/high-resolution examples, to train a SwinIR super-resolution model.

This notebook covers:
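1. **SIE Lens Modeling**: Simulating lenses with a Singular Isothermal Ellipsoid (SIE) mass profile parameterized by the velocity dispersion $\sigma_v$ (currently with zero ellipticity, i.e. effectively a singular isothermal sphere).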
2. **Real Galaxy Sources**: Using `Galaxy10_DECals` images as the source light to achieve realistic morphological variety.
3. **Multi-Resolution Simulation**: Creating paired High-Resolution (HR, 128×128 at 0.05″/pixel) and Low-Resolution (LR, 64×64 at 0.1″/pixel) images of the same field, based on the Euclid VIS instrument configuration.
4. **Data Preprocessing**: Normalization, filtering, and dataset preparation.

### 1. Setup and Dependencies

In [None]:
import copy
import numpy as np
import h5py
import random
import os
import cv2
from collections import defaultdict
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

from lenstronomy.SimulationAPI.sim_api import SimAPI
from lenstronomy.SimulationAPI.ObservationConfig.Euclid import Euclid
from astropy.cosmology import FlatLambdaCDM

# Configuration (paths are relative to the notebook's working directory).
OUTPUT_DIR = "../pairs"  # per-pair .npy files written by the simulation loop
DATA_FILE = r"../../DeepLenseSim/data/Galaxy10_DECals.h5"  # Galaxy10 DECals HDF5 catalogue
PROCESSED_DATA_DIR = "../data_diff"  # aggregated train/test arrays
NUM_SIMS = 2500  # number of HR/LR pairs to simulate
SEED = 42

# All random draws in this notebook go through NumPy's global RNG.
np.random.seed(SEED)
for _directory in (OUTPUT_DIR, PROCESSED_DATA_DIR):
    os.makedirs(_directory, exist_ok=True)
del _directory

print(f"Target Output Directory: {OUTPUT_DIR}")

### 2. Simulation Physics Engine
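We use `lenstronomy` to simulate the gravitational lensing effect. Starting from the Euclid VIS instrument configuration, each system is rendered twice over the same 6.4″ field: at 0.1″/pixel (LR, 64×64) and at 0.05″/pixel (HR, 128×128), giving a ×2 super-resolution factor.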

In [None]:
# Flat LambdaCDM cosmology shared by every SimAPI model built in get_simulation_api.
cosmo = FlatLambdaCDM(Om0=0.3, H0=70, Ob0=0.0)

# Euclid observation settings: Gaussian PSF, 6-year coadd.
# NOTE(review): the band configured here is 'VIS' even though the names carry a "_g" suffix.
Euclid_g = Euclid(coadd_years=6, psf_type='GAUSSIAN', band='VIS')
kwargs_g_band = Euclid_g.kwargs_single_band()

def get_simulation_api(numpix, kwargs_band):
    """Build a lenstronomy ``SimAPI`` for a ``numpix`` x ``numpix`` image.

    The model is a single SIE lens mass at redshift 0.5, two elliptical Sersic
    lens-light components, and one interpolated-image (``INTERPOL``) source,
    using the module-level ``cosmo``.

    Args:
        numpix: Number of pixels per image side.
        kwargs_band: Single-band observation kwargs (e.g. the output of
            ``Euclid.kwargs_single_band()`` with ``pixel_scale`` adjusted).

    Returns:
        SimAPI: The configured simulation API.

    NOTE(review): ``source_redshift_list`` is [1.0] while ``z_source`` and
    ``z_source_convention`` are 2.5. Confirm which source redshift is intended;
    the sigma_v -> lensing conversion presumably depends on it.
    """
    model_config = dict(
        lens_model_list=['SIE'],
        lens_redshift_list=[0.5],
        lens_light_model_list=['SERSIC_ELLIPSE'] * 2,
        source_light_model_list=['INTERPOL'],
        source_redshift_list=[1.0],
        cosmo=cosmo,
        z_source_convention=2.5,
        z_source=2.5,
    )
    return SimAPI(
        numpix=numpix,
        kwargs_single_band=kwargs_band,
        kwargs_model=model_config,
    )

def simulate_pair(image_galaxy, sigma_v, source_pos_xx, source_pos_yy, source_ang, idx):
    """Render one lensed system as a noisy (HR, LR) image pair.

    HR is 128x128 pixels with pixel_scale 0.05 and LR is 64x64 pixels with
    pixel_scale 0.1, so both cover the same field at a x2 sampling ratio.
    The lens mass is an SIE centred at the origin with zero ellipticity. The
    lens light is two elliptical Sersic components, and the source light is
    the supplied galaxy image through the INTERPOL profile.

    Args:
        image_galaxy: Galaxy cutout used as the source; only ``[:, :, 0]`` is
            used (assumed shape (H, W, C) with H, W >= 50 -- TODO confirm).
        sigma_v: Lens velocity dispersion, converted by
            ``physical2lensing_conversion`` (presumably km/s).
        source_pos_xx: Source centre x offset (presumably arcsec).
        source_pos_yy: Source centre y offset (presumably arcsec).
        source_ang: Source image rotation ``phi_G`` (presumably radians).
        idx: Not used by the simulation; kept for call-site compatibility.

    Returns:
        tuple: ``(image_hr, image_lr)`` noisy model images.
    """
    def _make_simulator(numpix, pixel_scale):
        # Copy the shared band config so the module-level kwargs_g_band is never mutated.
        # NOTE(review): simulators are rebuilt on every call; lenstronomy's INTERPOL
        # profile may cache its interpolated image, so confirm before reusing them.
        band_kwargs = copy.deepcopy(kwargs_g_band)
        band_kwargs['pixel_scale'] = pixel_scale
        simulator = get_simulation_api(numpix, band_kwargs)
        image_model = simulator.image_model_class({'point_source_supersampling_factor': 1})
        return simulator, image_model

    sim_hr, imSim_hr = _make_simulator(128, 0.05)  # HR: x2 finer sampling
    sim_lr, imSim_lr = _make_simulator(64, 0.1)    # LR: base sampling

    # Lens mass: SIE at the origin with e1 = e2 = 0.
    kwargs_mass = [dict(sigma_v=sigma_v, center_x=0, center_y=0, e1=0.0, e2=0)]

    # Source light: first channel of the cutout, background-subtracted with the
    # median of its top-left 50x50 corner.
    source_pixels = image_galaxy[:, :, 0].astype(float)
    source_pixels -= np.median(source_pixels[:50, :50])
    kwargs_source_mag = [dict(magnitude=22, image=source_pixels, scale=0.0025,
                              phi_G=source_ang, center_x=source_pos_xx, center_y=source_pos_yy)]

    # Lens light: a compact bright Sersic plus a faint, more extended one.
    kwargs_lens_light_mag = [
        dict(magnitude=17, R_sersic=0.4, n_sersic=2.3, e1=0, e2=0.05, center_x=0, center_y=0),
        dict(magnitude=28, R_sersic=1.5, n_sersic=1.2, e1=0, e2=0.3, center_x=0, center_y=0),
    ]

    # Magnitudes -> amplitudes, once per simulator.
    light_hr, source_hr, _ = sim_hr.magnitude2amplitude(kwargs_lens_light_mag, kwargs_source_mag)
    light_lr, source_lr, _ = sim_lr.magnitude2amplitude(kwargs_lens_light_mag, kwargs_source_mag)
    # Lens kwargs come from the HR simulator and are reused for LR (both use the same model).
    kwargs_lens = sim_hr.physical2lensing_conversion(kwargs_mass=kwargs_mass)

    image_hr = imSim_hr.image(kwargs_lens, source_hr, light_hr)
    image_lr = imSim_lr.image(kwargs_lens, source_lr, light_lr)

    # Instrument noise: HR is drawn first and scaled by 0.5; LR gets full-strength noise.
    image_hr += 0.5 * sim_hr.noise_for_model(model=image_hr)
    image_lr += sim_lr.noise_for_model(model=image_lr)

    return image_hr, image_lr

### 3. Load Real Galaxies
We load unbarred tight-spiral galaxies (Galaxy10 class 6) with redshift $z < 0.02$ from the `Galaxy10_DECals` dataset to use as lensed sources.

In [None]:
# Galaxy10 DECals label used as the source population (6 = "Unbarred Tight Spiral")
# and the redshift cut applied to it.
SOURCE_GALAXY_CLASS = 6
MAX_SOURCE_REDSHIFT = 0.02

if not os.path.exists(DATA_FILE):
    raise FileNotFoundError(f"Galaxy10 data not found at {DATA_FILE}")

# The whole image cube is loaded because the simulation loop indexes it with
# catalogue-wide indices (spiral_indices).
with h5py.File(DATA_FILE, 'r') as F:
    images = np.array(F['images'])
    labels = np.array(F['ans'])
    redshift = np.array(F['redshift'])

# Keep the chosen morphology below the redshift cut (NaN redshifts compare False and are dropped).
spiral_indices = np.where((labels == SOURCE_GALAXY_CLASS) & (redshift < MAX_SOURCE_REDSHIFT))[0]
print(f"Found {len(spiral_indices)} suitable galaxies.")
if len(spiral_indices) == 0:
    # Fail here with a clear message rather than inside np.random.randint in the simulation loop.
    raise ValueError(
        f"No galaxies with class {SOURCE_GALAXY_CLASS} and redshift < {MAX_SOURCE_REDSHIFT} "
        f"found in {DATA_FILE}"
    )

### 4. Run Simulation Loop

In [None]:
print(f"Simulating {NUM_SIMS} pairs...")
for pair_id in tqdm(range(NUM_SIMS)):
    # Choose a random source galaxy from the filtered catalogue.
    galaxy_index = spiral_indices[np.random.randint(0, len(spiral_indices))]

    # Stochastic lensing parameters. The draw order (galaxy, sigma_v, x, y, angle)
    # fixes the RNG stream, so reordering it changes every simulated pair.
    sigma_v = np.random.normal(260, 20)
    source_pos_x = np.random.uniform(-0.3, 0.3)
    source_pos_y = np.random.uniform(-0.3, 0.3)
    source_ang = np.random.uniform(-np.pi, np.pi)

    hr, lr = simulate_pair(images[galaxy_index], sigma_v, source_pos_x, source_pos_y,
                           source_ang, pair_id)

    # NOTE(review): HR is stored with the "hsc" suffix and LR with "hst"; the
    # aggregation step reads them back with the same mapping.
    np.save(os.path.join(OUTPUT_DIR, f'{pair_id}_lensing_hsc.npy'), hr)
    np.save(os.path.join(OUTPUT_DIR, f'{pair_id}_lensing_hst.npy'), lr)

### 5. Post-Processing & Dataset Aggregation
This section consolidates the individual `.npy` files into single training tensors. Each image is min-max normalized independently to the $[-1, 1]$ range, and the pairs are split 90/10 into train and test sets.

In [None]:
_HR_SUFFIX = "_lensing_hsc.npy"  # HR files, as written by the simulation loop
_LR_SUFFIX = "_lensing_hst.npy"  # LR files, as written by the simulation loop


def _index_pair_files(names, suffix):
    """Map integer pair id -> filename for every ``<digits><suffix>`` entry in *names*."""
    found = {}
    for name in names:
        if not name.endswith(suffix):
            continue
        stem = name[:-len(suffix)]
        if stem.isdecimal():
            found[int(stem)] = name
    return found


def _minmax_to_pm1(image):
    """Min-max scale *image* to [-1, 1]; the epsilon keeps constant images finite."""
    lo, hi = np.min(image), np.max(image)
    return 2 * ((image - lo) / (hi - lo + 1e-7)) - 1


def process_and_save(output_dir=None, processed_dir=None, train_fraction=0.9):
    """Aggregate simulated HR/LR pairs into normalized train/test arrays.

    Reads ``<idx>_lensing_hsc.npy`` (HR) and ``<idx>_lensing_hst.npy`` (LR) from
    *output_dir*. Each image is resized to 128x128 (HR) or 64x64 (LR) and min-max
    scaled independently to [-1, 1]. The results are stacked as ``(N, 1, H, W)``
    arrays and saved as ``train_HR.npy``, ``train_LR.npy``, ``test_HR.npy`` and
    ``test_LR.npy`` in *processed_dir*.

    Args:
        output_dir: Directory holding the per-pair files; defaults to ``OUTPUT_DIR``.
        processed_dir: Destination directory; defaults to ``PROCESSED_DATA_DIR``.
        train_fraction: Leading fraction of pairs (in ascending id order) used for training.

    Raises:
        FileNotFoundError: If *output_dir* contains no complete HR/LR pair.
    """
    output_dir = OUTPUT_DIR if output_dir is None else output_dir
    processed_dir = PROCESSED_DATA_DIR if processed_dir is None else processed_dir

    names = os.listdir(output_dir)
    hr_files = _index_pair_files(names, _HR_SUFFIX)
    lr_files = _index_pair_files(names, _LR_SUFFIX)
    # Only complete pairs. An interrupted simulation run can leave a lone HR or LR
    # file, and unrelated .npy files are ignored rather than crashing the id parse.
    indices = sorted(hr_files.keys() & lr_files.keys())
    if not indices:
        raise FileNotFoundError(f"No complete HR/LR pairs found in {output_dir}")

    HR_list, LR_list = [], []
    for idx in tqdm(indices, desc="Processing"):
        hr = np.load(os.path.join(output_dir, hr_files[idx]))
        lr = np.load(os.path.join(output_dir, lr_files[idx]))

        # Resize BEFORE normalizing: bicubic interpolation can overshoot the input
        # range, so scaling afterwards keeps every output inside [-1, 1].
        hr = cv2.resize(hr, (128, 128), interpolation=cv2.INTER_CUBIC)
        lr = cv2.resize(lr, (64, 64), interpolation=cv2.INTER_CUBIC)

        # Per-image min-max scaling to [-1, 1], then add a leading channel axis.
        HR_list.append(_minmax_to_pm1(hr)[np.newaxis, ...])
        LR_list.append(_minmax_to_pm1(lr)[np.newaxis, ...])

    HR = np.array(HR_list)  # (N, 1, 128, 128)
    LR = np.array(LR_list)  # (N, 1, 64, 64)

    # Ordered split without shuffling; each pair presumably comes from independent random draws.
    split_idx = int(train_fraction * len(HR))

    os.makedirs(processed_dir, exist_ok=True)
    np.save(os.path.join(processed_dir, 'train_HR.npy'), HR[:split_idx])
    np.save(os.path.join(processed_dir, 'train_LR.npy'), LR[:split_idx])
    np.save(os.path.join(processed_dir, 'test_HR.npy'), HR[split_idx:])
    np.save(os.path.join(processed_dir, 'test_LR.npy'), LR[split_idx:])

    print(f"Saved processed data to {processed_dir}")
    print(f"Train Shape: {HR[:split_idx].shape}")


process_and_save()

### 6. Visualization
Visually compare a random LR/HR training pair to verify the resolution difference.

In [None]:
# Show one random training pair side by side: LR input on the left, HR target on the right.
HR = np.load(os.path.join(PROCESSED_DATA_DIR, 'train_HR.npy'))
LR = np.load(os.path.join(PROCESSED_DATA_DIR, 'train_LR.npy'))

idx = np.random.randint(0, len(HR))
panels = [
    (LR[idx], "Low Resolution Input (64x64)"),
    (HR[idx], "High Resolution Ground Truth (128x128)"),
]
plt.figure(figsize=(12, 6))
for position, (panel_image, panel_title) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.imshow(panel_image.squeeze(), cmap='inferno')
    plt.title(panel_title)
plt.show()