Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions lars/preprocessing/radar_preprocessing.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import xradar as xd
import matplotlib.pyplot as plt
import glob
import numpy as np
import os
import pandas as pd
import cmweather # noqa
import cmweather # noqa


def preprocess_radar_data(file_path, output_path, date=None,
Expand Down Expand Up @@ -66,19 +67,23 @@ def preprocess_radar_data(file_path, output_path, date=None,
radar = radar.xradar.georeference()
if 'sweep_0' in radar:
sweep = radar['sweep_0']
if sweep["sweep_mode"] == 'ppi' or sweep["sweep_mode"] == 'sector':
sweep_mode = str(sweep["sweep_mode"].values).split('\x00')[0].strip()
if sweep_mode in ('ppi', 'sector', 'azimuth_surveillance'):
fig = plt.figure(figsize=(size_px/dpi, size_px/dpi))
ax = plt.axes()
sweep["corrected_reflectivity"].where(
sweep["corrected_reflectivity"] > min_ref).plot(x="x", y="y",
ax=ax,
add_colorbar=False,
**kwargs)
min_ref = sweep["corrected_reflectivity"].where(
sweep["corrected_reflectivity"] > min_ref).values.min()
max_ref = sweep["corrected_reflectivity"].where(
sweep["corrected_reflectivity"] > min_ref).values.max()

masked = sweep["corrected_reflectivity"].where(
sweep["corrected_reflectivity"] > min_ref).values
ref_min = np.nanmin(masked)
ref_max = np.nanmax(masked)
ax.axis('off')
ax.set_title('')
ax.set_ylabel('')
ax.set_xlabel('')
ax.set_xlim(x_bounds)
ax.set_ylim(y_bounds)

Expand All @@ -91,7 +96,7 @@ def preprocess_radar_data(file_path, output_path, date=None,
os.path.basename(file).replace('.nc', '.png')),
dpi=dpi, bbox_inches='tight', pad_inches=0)
plt.close(fig)
out_df.loc[len(out_df)] = [file_name, time_str, label, min_ref, max_ref]
out_df.loc[len(out_df)] = [file_name, time_str, label, ref_min, ref_max]

else:
print(f"Sweep mode is not PPI or sector scan in {file}, skipping.")
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ dependencies = ["xradar", "scikit-learn", "python-dotenv", "aiohttp", "asksagecl
[project.optional-dependencies]
dev = ["pytest>=6.0", "pytest-asyncio>=0.21", "black", "flake8", "openai", "xradar", "python-dotenv",
"scikit-learn", "cmweather", "torchvision", "torch", "aiohttp", "matplotlib", "pandas",
"asksageclient", "pip_system_certs", "requests"]
"asksageclient", "pip_system_certs", "requests", "open-radar-data"]

[project.urls]
Homepage = "https://github.com/rcjackson/lars"
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
25 changes: 25 additions & 0 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import sys

# The parent tests/conftest.py mocks these at import time. Remove the mocks
# so integration tests can use the real implementations.
_MOCKED = ["xradar", "cmweather", "asksageclient", "pip_system_certs"]
for _key in list(sys.modules):
if any(_key == m or _key.startswith(m + ".") for m in _MOCKED):
del sys.modules[_key]

# Evict any cached lars imports so they re-link against the real deps.
for _key in list(sys.modules):
if _key == "lars" or _key.startswith("lars."):
del sys.modules[_key]

import matplotlib
matplotlib.use("Agg")


def pytest_addoption(parser):
parser.addoption(
"--generate-baseline",
action="store_true",
default=False,
help="Write baseline images instead of comparing against them.",
)
144 changes: 144 additions & 0 deletions tests/integration/test_radar_preprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
"""
Integration tests for lars.preprocessing.preprocess_radar_data.

Downloads three GUC XPRECIP CMAC PPI radar files from open-radar-data,
runs the full preprocessing workflow, validates the returned DataFrame,
and compares each generated PNG against a stored baseline image.

Generating baselines (first-time setup or after intentional changes):
pytest tests/integration/ --generate-baseline

Running the tests normally:
pytest tests/integration/
"""

import os
import shutil

import matplotlib.image as mpimg
import numpy as np
import pytest

open_radar_data = pytest.importorskip("open_radar_data")
xradar = pytest.importorskip("xradar")

RADAR_FILES = [
"gucxprecipradarcmacppiS2.c1.20220314.021559.nc",
"gucxprecipradarcmacppiS2.c1.20220314.024239.nc",
"gucxprecipradarcmacppiS2.c1.20220314.025840.nc",
]

BASELINE_DIR = os.path.join(
os.path.dirname(__file__), "..", "data", "baseline", "preprocessing"
)

# Pixel-value tolerance for image comparison (values are float in [0, 1]).
IMAGE_TOLERANCE = 5 / 255


@pytest.fixture(scope="module")
def radar_data_dir(tmp_path_factory):
"""Download the three test radar files into an isolated temp directory."""
from open_radar_data import DATASETS

tmp_dir = tmp_path_factory.mktemp("radar_data")
for fname in RADAR_FILES:
src = DATASETS.fetch(fname)
shutil.copy(src, tmp_dir / fname)
return str(tmp_dir)


@pytest.fixture(scope="module")
def preprocessing_output(tmp_path_factory, radar_data_dir):
"""Run preprocessing once and share the output across all tests."""
from lars.preprocessing import preprocess_radar_data

out_dir = str(tmp_path_factory.mktemp("preprocessing_output"))
label_df = preprocess_radar_data(radar_data_dir, out_dir)
return out_dir, label_df


# ---------------------------------------------------------------------------
# DataFrame tests
# ---------------------------------------------------------------------------


def test_dataframe_row_count(preprocessing_output):
_, label_df = preprocessing_output
assert len(label_df) == 3


def test_dataframe_columns(preprocessing_output):
_, label_df = preprocessing_output
assert set(label_df.columns) == {"file_path", "label", "ref_min", "ref_max"}


def test_labels_are_unknown(preprocessing_output):
_, label_df = preprocessing_output
assert (label_df["label"] == "UNKNOWN").all()


def test_reflectivity_bounds(preprocessing_output):
_, label_df = preprocessing_output
assert (label_df["ref_min"] <= label_df["ref_max"]).all()


def test_timestamps_are_on_correct_date(preprocessing_output):
_, label_df = preprocessing_output
assert all("2022-03-14" in str(idx) for idx in label_df.index)


def test_index_is_sorted(preprocessing_output):
_, label_df = preprocessing_output
assert label_df.index.is_monotonic_increasing


# ---------------------------------------------------------------------------
# Image file tests
# ---------------------------------------------------------------------------


def test_png_files_created(preprocessing_output):
out_dir, _ = preprocessing_output
for fname in RADAR_FILES:
assert os.path.exists(os.path.join(out_dir, fname.replace(".nc", ".png")))


# ---------------------------------------------------------------------------
# Baseline image-comparison tests
# ---------------------------------------------------------------------------


def _compare_to_baseline(generated_path, baseline_path, tolerance):
generated = mpimg.imread(generated_path).astype(np.float32)
baseline = mpimg.imread(baseline_path).astype(np.float32)

assert generated.shape == baseline.shape, (
f"Shape mismatch: generated {generated.shape} vs baseline {baseline.shape}"
)
max_diff = np.max(np.abs(generated - baseline))
assert max_diff <= tolerance, (
f"Max pixel difference {max_diff:.4f} exceeds tolerance {tolerance:.4f} "
f"({os.path.basename(generated_path)})"
)


@pytest.mark.parametrize("fname", RADAR_FILES)
def test_image_matches_baseline(request, preprocessing_output, fname):
out_dir, _ = preprocessing_output
png_name = fname.replace(".nc", ".png")
generated_path = os.path.join(out_dir, png_name)
baseline_path = os.path.join(BASELINE_DIR, png_name)

if request.config.getoption("--generate-baseline"):
os.makedirs(BASELINE_DIR, exist_ok=True)
shutil.copy(generated_path, baseline_path)
pytest.skip(f"Baseline written to {baseline_path}")

if not os.path.exists(baseline_path):
pytest.skip(
f"No baseline found at {baseline_path}. "
"Run with --generate-baseline to create it."
)

_compare_to_baseline(generated_path, baseline_path, IMAGE_TOLERANCE)
Loading