In [None]:
import ast
import gc
import os
import shutil
import struct
import sys
import time
from glob import glob
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import psutil
import rasterio
from scipy import ndimage
from scipy.ndimage import uniform_filter
from scipy.stats import gamma as gamma_dist
from tqdm.notebook import tqdm

# Project root is two levels up from learningnotebooks/phase4_sar_codec/.
project_root = Path.cwd().parents[1]
print(project_root)

# Make the project's `src/` importable; guard so re-running the cell does not
# keep prepending duplicate entries to sys.path.
src_dir = str(project_root / "src")
if src_dir not in sys.path:
    sys.path.insert(0, src_dir)

from utils.io import load_sar_image, find_all_sar_files, get_info
from data.preprocessing import preprocess_sar_complete, extract_patches


# Data locations and patch geometry (patchsize / stride in pixels).
rawdir = project_root/"data"/"raw"
outputdir = project_root/"data"/"patches"
checkpointdir = project_root/"checkpoints"
patchsize = 256
stride = 128
print(rawdir, outputdir)

In [None]:
# Collect the SAR measurement files under data/raw via the project helper.
# NOTE(review): assumed to return an indexable sequence of paths (later cells
# call len() on it and index [0]); TODO confirm it yields only .tiff rasters.
sar_files = find_all_sar_files(rawdir)

In [None]:
# Every Sentinel-1 product folder (*.SAFE) directly under data/raw.
safe_folders = [safe_dir for safe_dir in rawdir.glob("*.SAFE")]
print(f"Total .SAFE folders: {len(safe_folders)}\n")

In [None]:
# Classify each .SAFE product: "complete" when its measurement/ folder holds at
# least two .tiff rasters; otherwise record why it is incomplete.
complete = []
incomplete = []

for folder in safe_folders:
    measurement_dir = folder / "measurement"
    if not measurement_dir.exists():
        incomplete.append((folder.name, "no measurement folder"))
        continue
    n_tiffs = len([t for t in measurement_dir.glob("*.tiff")])
    if n_tiffs < 2:
        incomplete.append((folder.name, n_tiffs))
    else:
        complete.append(folder.name)

print(f"Complete: {len(complete)}")
print(f"Incomplete: {len(incomplete)}\n")

if len(incomplete) > 0:
    print("Incomplete folders:")
    for folder_name, problem in incomplete:
        print(f"  {folder_name}: {problem}")

In [None]:
# Debug: inspect the on-disk layout of a single .SAFE product.
folder = rawdir / "S1A_IW_GRDH_1SDV_20260117T061452_20260117T061517_062803_07E08D_824A.SAFE"

print(f"Folder exists: {folder.exists()}")
print("\nContents:")
for entry in folder.iterdir():
    print(f"  {entry.name}")

# Check if measurement exists with different casing
print("\nMeasurement folder check:")
for dirname in ("measurement", "MEASUREMENT"):
    print(f"  '{dirname}' exists: {(folder / dirname).exists()}")

In [None]:
# One line per file: satellite | date ('?' if unknown) | truncated file name.
print(f"Found {len(sar_files)} TIFF files:\n")
for sar_path in sar_files:
    file_info = get_info(sar_path)
    satellite = file_info['satellite'].upper()
    date = file_info.get('date', '?')
    short_name = Path(sar_path).name[:50]
    print(f"  {satellite} | {date} | {short_name}...")

In [None]:
# Inspect the first scene's raster header. The in-memory size is derived from
# the header dimensions rather than by reading the whole band, which would
# materialise a multi-GB array just to print its size.
if not sar_files:
    raise FileNotFoundError(f"No SAR files found under {rawdir}")
path = sar_files[0]
with rasterio.open(path) as src:
    print(f"Shape: {src.shape}")
    print(f"Dtype: {src.dtypes[0]}")
    # Bytes band 1 would occupy once cast to float32.
    float32_bytes = src.height * src.width * np.dtype(np.float32).itemsize
    print(f"Array size: {float32_bytes / 1e9:.2f} GB")

## Test: plotting image bands

In [None]:

# Reference plotting helper, kept commented out (not executed).
# def showImage(raster,tindex,bandnbr,vmin=None,vmax=None):
#     fig = plt.figure(figsize=(16,8))
#     ax1 = fig.add_subplot(121)
#     ax2 = fig.add_subplot(122)
#     ax1.imshow(raster,cmap='gray',vmin=vmin,vmax=vmax)
#     ax1.set_title('Image Band {} {}'.format(bandnbr,
#     tindex[bandnbr-1].date()))
#     vmin=np.percentile(raster,2) if vmin==None else vmin
#     vmax=np.percentile(raster,98) if vmax==None else vmax
#     ax1.xaxis.set_label_text(
#     'Linear stretch Min={} Max={}'.format(vmin,vmax))
#     h = ax2.hist(raster.flatten(),bins=100,range=(0,8000))
#     ax2.xaxis.set_label_text('Amplitude (Uncalibrated DN Values)')
#     ax2.set_title('Histogram Band {} {}'.format(bandnbr,
#     tindex[bandnbr-1].date()))

In [None]:
# Open one VH measurement directly for a quick look at band count / mode.
# The path is built with pathlib: the previous backslash-separated f-string
# was Windows-only and contained invalid escapes (`\S`, `\m`, `\s`).
testtiff_path = (
    rawdir
    / "S1A_IW_GRDH_1SDV_20260116T113541_20260116T113606_062792_07E02D_AC54.SAFE"
    / "measurement"
    / "s1a-iw-grd-vh-20260116t113541-20260116t113606-062792-07e02d-002.tiff"
)
# Left open on purpose: the next cell reads testtiff.width / testtiff.height.
testtiff = rasterio.open(testtiff_path)
print(testtiff.count, testtiff.mode)

In [None]:
# Raster dimensions of the test VH measurement: width (columns), height (rows).
print(testtiff.width, testtiff.height)

## Preprocessing (continued): global dB bounds and patch extraction

In [None]:
bounds_file = checkpointdir / "global_bounds.npy"
n_sample = 100_000  # max pixels per scene fed to the percentile estimate

if bounds_file.exists():
    # Reuse the cached bounds so the expensive scan below runs only once.
    bounds = np.load(bounds_file, allow_pickle=True).item()
    global_vmin = bounds["vmin"]
    global_vmax = bounds["vmax"]
    print(f"Loaded bounds: [{global_vmin:.2f}, {global_vmax:.2f}] dB")
else:
    # Pass 1: for each scene take the 1st / 99th percentile (in dB) of a random
    # subsample of its positive pixels; the global bounds are the medians of
    # those per-scene values.
    print("No saved bounds, computing...")

    rng = np.random.default_rng(42)  # fixed seed -> reproducible bounds
    all_vmins = []
    all_vmaxs = []

    for path in tqdm(sar_files, desc="Scanning"):
        # tqdm.write keeps the progress bar intact, unlike print().
        tqdm.write(f"\n{path}")

        t0 = time.time()
        with rasterio.open(path) as src:
            image = src.read(1).astype(np.float32)
        tqdm.write(f"  Read: {time.time() - t0:.1f}s")

        t0 = time.time()
        # Non-positive pixels are excluded (presumably no-data / border fill).
        valid = image[image > 0]
        del image
        gc.collect()
        tqdm.write(f"  Valid mask: {time.time() - t0:.1f}s")

        if len(valid) == 0:
            continue

        t0 = time.time()
        if len(valid) > n_sample:
            # Sampling with replacement is adequate for a percentile estimate.
            valid = valid[rng.integers(0, len(valid), n_sample)]
        tqdm.write(f"  Subsample: {time.time() - t0:.1f}s")

        t0 = time.time()
        image_db = 10 * np.log10(np.maximum(valid, 1e-10))
        all_vmins.append(np.percentile(image_db, 1))
        all_vmaxs.append(np.percentile(image_db, 99))
        tqdm.write(f"  Percentiles: {time.time() - t0:.1f}s")

        del valid, image_db
        gc.collect()

    if not all_vmins:
        # np.median([]) would silently yield NaN bounds and poison every patch.
        raise RuntimeError("No scene contained valid (> 0) pixels; cannot compute global bounds")

    global_vmin = np.median(all_vmins)
    global_vmax = np.median(all_vmaxs)

    print(f"\nGlobal bounds: [{global_vmin:.2f}, {global_vmax:.2f}] dB")
    # The checkpoint folder may not exist on a fresh checkout.
    checkpointdir.mkdir(parents=True, exist_ok=True)
    np.save(bounds_file, {
        "vmin" : global_vmin,
        "vmax" : global_vmax,
        "all_vmins" : all_vmins,
        "all_vmaxs" : all_vmaxs
    })

In [None]:

# Pass 2: normalise every scene with the global dB bounds, cut it into
# overlapping patches and save one `<stem>_patches.npy` per input file.
min_valid = 0.9  # passed to extract_patches (presumably min valid-pixel fraction per patch)
max_rss_gb = 25  # stop before the process RSS exceeds this (avoid an OOM crash)

filestats = []
outputdir.mkdir(parents=True, exist_ok=True)

print("Pass 2 Extracting Patches...\n")

n_processed = 0
for path in tqdm(sar_files, desc = "Processing"):
    filename = Path(path).stem

    mem_gb = psutil.Process().memory_info().rss / 1e9
    if mem_gb > max_rss_gb:  # bail out before crash
        print(f"WARNING: Memory at {mem_gb:.1f} GB, stopping early")
        break

    image = load_sar_image(path)
    image_shape = image.shape

    normalised, params = preprocess_sar_complete(
        image,
        vmin=global_vmin,
        vmax=global_vmax
    )
    del image
    gc.collect()

    # Extract patches
    patches, positions = extract_patches(
        normalised,
        patch_size=patchsize,
        stride=stride,
        min_valid=min_valid
    )
    del normalised
    gc.collect()

    n_patches = len(patches)
    filestats.append({
        'filename': filename,
        'shape': image_shape,
        'patches': n_patches
    })
    if n_patches > 0:
        # Write to a temp name and rename, so an interrupted save never leaves
        # a truncated .npy that the resume cell would treat as finished.
        final_path = outputdir / f"{filename}_patches.npy"
        tmp_path = final_path.with_name(final_path.name + ".tmp")
        with open(tmp_path, "wb") as fh:
            np.save(fh, patches)
        os.replace(tmp_path, final_path)

    del patches, positions
    gc.collect()
    n_processed += 1

    tqdm.write(f"  {filename[:40]}... : {n_patches} patches | Mem: {mem_gb:.1f} GB")

# Report what actually ran: the memory guard may have stopped the loop early.
print(f"\nProcessed {n_processed}/{len(sar_files)} files")

In [None]:
# Progress check: how many per-scene patch files exist, and how much disk is
# left on the drive that holds them. Uses the configured `outputdir` (where
# Pass 2 writes) and ignores metadata.npy / shuffle_idx.npy / all_patches.npy.
existing = sorted(
    f for f in outputdir.glob("*_patches.npy") if f.name != "all_patches.npy"
)
print(f"Files saved: {len(existing)}")

# Check disk space on the output drive (disk_usage needs an existing path).
total, used, free = shutil.disk_usage(outputdir if outputdir.exists() else project_root)
print(f"Free space: {free / 1e9:.1f} GB")

In [None]:
# Resume Pass 2: process only scenes that do not yet have a saved
# `<stem>_patches.npy` (e.g. after the memory guard stopped a previous run).
patch_suffix = "_patches.npy"
max_rss_gb = 25  # stop before the process RSS exceeds this (avoid an OOM crash)

outputdir.mkdir(parents=True, exist_ok=True)
# Only per-scene patch files mark a scene as done; metadata.npy,
# shuffle_idx.npy and all_patches.npy share the folder and must not count.
existing = {
    f.name[: -len(patch_suffix)]
    for f in outputdir.glob(f"*{patch_suffix}")
    if f.name != "all_patches.npy"
}
print(f"Already saved: {len(existing)} files")

# Find remaining files. NOTE: scenes that yield 0 patches are never saved,
# so they are re-processed on every resume.
remaining = [p for p in sar_files if Path(p).stem not in existing]
print(f"Remaining: {len(remaining)} files")

# Process only remaining files
for path in tqdm(remaining, desc="Processing"):
    filename = Path(path).stem

    mem_gb = psutil.Process().memory_info().rss / 1e9
    if mem_gb > max_rss_gb:
        print(f"WARNING: Memory at {mem_gb:.1f} GB, stopping early")
        break

    image = load_sar_image(path)
    image_shape = image.shape

    normalised, params = preprocess_sar_complete(
        image, vmin=global_vmin, vmax=global_vmax
    )
    del image
    gc.collect()

    patches, positions = extract_patches(
        normalised, patch_size=patchsize, stride=stride, min_valid=min_valid
    )
    del normalised
    gc.collect()

    n_patches = len(patches)
    filestats.append({
        'filename': filename,
        'shape': image_shape,
        'patches': n_patches
    })

    if n_patches > 0:
        # Temp file + rename: a crash mid-write must not leave a truncated
        # .npy that the `existing` check above would count as finished.
        final_path = outputdir / f"{filename}{patch_suffix}"
        tmp_path = final_path.with_name(final_path.name + ".tmp")
        with open(tmp_path, "wb") as fh:
            np.save(fh, patches)
        os.replace(tmp_path, final_path)

    del patches, positions
    gc.collect()

    tqdm.write(f"  {filename[:40]}... : {n_patches} patches | Mem: {mem_gb:.1f} GB")

print("Done!")

In [None]:
# Cross-check which input scenes have no saved patch file yet.
patch_suffix = "_patches.npy"

# Get all tiff filenames (stems)
tiff_stems = {Path(p).stem for p in sar_files}

# Saved scenes: only `<stem>_patches.npy`; metadata.npy, shuffle_idx.npy and
# the combined all_patches.npy would otherwise inflate the count.
saved_stems = {
    f.name[: -len(patch_suffix)]
    for f in outputdir.glob(f"*{patch_suffix}")
    if f.name != "all_patches.npy"
}

# Find missing. NOTE: scenes that produced 0 patches are never saved and will
# always be listed here.
missing = tiff_stems - saved_stems

print(f"Total TIFFs: {len(tiff_stems)}")
print(f"Saved patches: {len(saved_stems)}")
print(f"Missing: {len(missing)}")

if missing:
    print("\nMissing files:")
    for m in sorted(missing):
        print(f"  {m}")

In [None]:
import struct

def get_npy_shape(filepath):
    """Return the array shape stored in a ``.npy`` file without reading its data.

    Only the fixed preamble and the header dict are read, so this is cheap even
    for multi-GB patch files. Supports .npy format versions 1.0, 2.0 and 3.0.

    Parameters
    ----------
    filepath : str or os.PathLike
        Path to a file written by ``np.save``.

    Returns
    -------
    tuple of int
        The header's ``shape`` entry (``()`` for a 0-d array).

    Raises
    ------
    ValueError
        If the file lacks the NumPy magic string, uses an unsupported format
        version, or its header has no ``shape`` entry.
    """
    with open(filepath, 'rb') as f:
        # Preamble: 6-byte magic string, then 1-byte major and minor version.
        if f.read(6) != b'\x93NUMPY':
            raise ValueError(f"{filepath} is not a .npy file")
        version = f.read(2)
        if len(version) < 2:
            raise ValueError(f"{filepath}: truncated .npy preamble")
        major = version[0]
        # v1.0 stores the header length as uint16; v2.0/v3.0 widen it to uint32.
        if major == 1:
            (header_len,) = struct.unpack('<H', f.read(2))
        elif major in (2, 3):
            (header_len,) = struct.unpack('<I', f.read(4))
        else:
            raise ValueError(f"{filepath}: unsupported .npy version {major}.{version[1]}")
        # v3.0 headers are UTF-8; earlier versions are latin-1.
        encoding = 'utf8' if major == 3 else 'latin1'
        # The header is a Python dict literal; parse it properly rather than
        # scanning for the first '(' (a structured dtype's descr contains
        # parentheses before 'shape').
        header = ast.literal_eval(f.read(header_len).decode(encoding))
    if not isinstance(header, dict) or 'shape' not in header:
        raise ValueError(f"{filepath}: .npy header has no 'shape' entry")
    return tuple(int(dim) for dim in header['shape'])

# Total patch count across per-scene files, read from .npy headers only.
# all_patches.npy (a combined copy) is skipped so patches are not counted twice.
total_patches = sum(
    get_npy_shape(f)[0]
    for f in outputdir.glob("*_patches.npy")
    if f.name != "all_patches.npy"
)

print(f"Total patches: {total_patches:,}")

In [None]:
# On-disk size of the per-scene patch files (excluding the combined copy).
total_bytes = sum(
    f.stat().st_size
    for f in outputdir.glob("*_patches.npy")
    if f.name != "all_patches.npy"
)
print(f"Total size: {total_bytes / 1e9:.1f} GB")

In [None]:
# Per-patch shape, taken from the first per-scene file. Sorted for a
# deterministic choice; all_patches.npy is excluded so it is never left
# memory-mapped (an open memmap blocks deleting it on Windows).
patch_files = sorted(
    f for f in outputdir.glob("*_patches.npy") if f.name != "all_patches.npy"
)
if not patch_files:
    raise FileNotFoundError(f"No *_patches.npy files in {outputdir}")

sample = np.load(patch_files[0], mmap_mode="r")
patch_shape = sample.shape[1:]
del sample  # release the memmap / file handle
print(f"Patch shape: {patch_shape}")

In [None]:
# Save metadata only.
# NOTE: this early version has no 'file_index'; the indexing cell below
# rewrites metadata.npy and shuffle_idx.npy with the file index included.
metadata = {
    'vmin': global_vmin,
    'vmax': global_vmax,
    'patch_size': patchsize,
    'stride': stride,
    'min_valid': min_valid,
    'num_patches': total_patches,
}
np.save(outputdir / 'metadata.npy', metadata, allow_pickle=True)

# Save shuffle index. Seeded for reproducibility: RandomState(42) gives the
# same permutation as np.random.seed(42) + np.random.permutation, without
# touching the global RNG.
shuffle_idx = np.random.RandomState(42).permutation(total_patches)
np.save(outputdir / "shuffle_idx.npy", shuffle_idx)

In [None]:
# Delete the old combined file if an earlier run created it. Tolerate it being
# absent so this cell can be re-run without a FileNotFoundError.
combined_path = outputdir / "all_patches.npy"
if combined_path.exists():
    combined_path.unlink()

# Re-run indexing (exclude all_patches.npy). Sorted so the global patch order,
# and therefore shuffle_idx, is reproducible.
patch_files = sorted(f for f in outputdir.glob("*_patches.npy") if f.name != "all_patches.npy")

# file_index[i] = (path, number of patches in that file), in global order.
file_index = []
for f in tqdm(patch_files, desc="Indexing"):
    shape = get_npy_shape(f)
    file_index.append((f, shape[0]))

total_patches = sum(n for _, n in file_index)
print(f"Total patches: {total_patches:,}")

np.random.seed(42)
shuffle_idx = np.random.permutation(total_patches)
np.save(outputdir / "shuffle_idx.npy", shuffle_idx)

# NOTE(review): file_index stores absolute paths, so metadata.npy is tied to
# this machine's directory layout.
metadata = {
    'vmin': global_vmin,
    'vmax': global_vmax,
    'patch_size': patchsize,
    'stride': stride,
    'min_valid': min_valid,
    'num_patches': total_patches,
    'file_index': [(str(f), n) for f, n in file_index]
}
np.save(outputdir / 'metadata.npy', metadata, allow_pickle=True)

print("Done!")

In [None]:
# Per-file patch counts, read from the .npy headers only (no data loaded).
for patch_file in sorted(outputdir.glob("*_patches.npy")):
    n_in_file = get_npy_shape(patch_file)[0]
    print(f"{patch_file.name}: {n_in_file:,} patches")

In [None]:
# Cell 6: Verify
import matplotlib.pyplot as plt

# Reload what the indexing step wrote to disk so this check does not depend on
# in-memory state from earlier cells.
metadata_path = outputdir / 'metadata.npy'
metadata = np.load(metadata_path, allow_pickle=True).item()
shuffle_idx = np.load(outputdir / "shuffle_idx.npy")

n_total = metadata['num_patches']
print(f"Total patches: {n_total:,}")
lo_db, hi_db = metadata['vmin'], metadata['vmax']
print(f"Global bounds: [{lo_db:.2f}, {hi_db:.2f}] dB")
print(f"Patch size: {metadata['patch_size']}")
print(f"Files: {len(metadata['file_index'])}")

# Load a few random patches for visualization
def load_random_patches(n=12, meta=None, order=None):
    """Load ``n`` randomly chosen patches from the per-scene patch files.

    Parameters
    ----------
    n : int
        Number of patches requested; capped at the number available.
    meta : dict, optional
        Metadata holding a ``'file_index'`` list of ``(path, patch_count)``
        pairs in global-index order. Defaults to the notebook-level
        ``metadata``.
    order : array-like of int, optional
        Permutation mapping positions to global patch indices. Defaults to
        the notebook-level ``shuffle_idx``.

    Returns
    -------
    list of numpy.ndarray
        Independent copies of the selected patches, grouped by source file
        (not in draw order).

    Raises
    ------
    IndexError
        If ``order`` references a global index beyond the patches listed in
        ``file_index``.
    """
    meta = metadata if meta is None else meta
    order = shuffle_idx if order is None else order

    file_index = meta['file_index']
    counts = np.array([count for _, count in file_index], dtype=np.int64)
    # ends[i] is the exclusive global end index of file i's patches.
    ends = np.cumsum(counts)

    n = min(n, len(order))
    if n <= 0:
        return []
    picks = np.random.choice(len(order), n, replace=False)

    # Group by file so each file is loaded at most once.
    file_requests = {}
    for pick in picks:
        global_idx = int(order[pick])
        # side='right' skips zero-patch files (their end equals the next start).
        file_idx = int(np.searchsorted(ends, global_idx, side='right'))
        if file_idx >= len(counts):
            n_known = int(ends[-1]) if len(ends) else 0
            raise IndexError(f"Patch index {global_idx} exceeds the {n_known} patches in file_index")
        start = int(ends[file_idx] - counts[file_idx])
        file_requests.setdefault(file_idx, []).append(global_idx - start)

    patches = []
    for file_idx, local_indices in file_requests.items():
        fpath, _ = file_index[file_idx]
        data = np.load(fpath)  # Full load, no mmap
        # Copy each patch: data[i] is a view that would keep the whole file's
        # array alive, so `del data` below would otherwise free nothing.
        patches.extend(data[i].copy() for i in local_indices)
        del data
        gc.collect()

    return patches

# Show random samples on a 3x4 grid.
samples = load_random_patches(12)

fig, axes = plt.subplots(3, 4, figsize=(12, 9))
# Blank every panel first so any unfilled ones (fewer samples) stay empty.
for ax in axes.flat:
    ax.axis('off')
for ax, patch in zip(axes.flat, samples):
    # NOTE(review): assumes patches are normalised to [0, 1] by preprocessing.
    ax.imshow(patch, cmap='gray', vmin=0, vmax=1)
fig.suptitle(f"Random samples from {metadata['num_patches']:,} patches")
fig.tight_layout()
plt.show()