From 3d1f83a419287890fc150ceebaf543bfc0017a62 Mon Sep 17 00:00:00 2001 From: Olufemi Taiwo Date: Wed, 15 Apr 2026 20:49:35 +0100 Subject: [PATCH 1/2] feat(inference): make pipeline analysis-aware with dynamic model loading - _load_model() now accepts analysis_type and reads in_channels/num_classes from config.yaml - Per-analysis-type model cache prevents cross-contamination between deforestation/ice/flood models - _find_best_checkpoint() prefers config.yaml weight path per analysis type - run_inference() accepts analysis_type, pads/crops to correct n_channels, and returns dynamic class counts - run_inference_from_file() and run_inference_from_gee() propagate analysis_type parameter --- src/climatevision/inference/pipeline.py | 219 +++++++++++++++--------- 1 file changed, 135 insertions(+), 84 deletions(-) diff --git a/src/climatevision/inference/pipeline.py b/src/climatevision/inference/pipeline.py index 77c6e30..9bbe25f 100644 --- a/src/climatevision/inference/pipeline.py +++ b/src/climatevision/inference/pipeline.py @@ -2,9 +2,9 @@ Inference pipeline for ClimateVision. Provides: -- run_inference(image_array, bbox, start_date, end_date) — core inference on a numpy array -- run_inference_from_file(path, bbox, start_date, end_date) — load file then infer -- run_inference_from_gee(bbox, start_date, end_date) — GEE NDVI + synthetic model inference +- run_inference(image_array, bbox, start_date, end_date, analysis_type) — core inference on a numpy array +- run_inference_from_file(path, bbox, start_date, end_date, analysis_type) — load file then infer +- run_inference_from_gee(bbox, start_date, end_date, analysis_type) — GEE NDVI + real tile inference """ from __future__ import annotations @@ -17,6 +17,7 @@ import numpy as np import torch +from climatevision.data.band_mapping import get_bands_for_analysis, get_model_config from climatevision.models.unet import UNet logger = logging.getLogger(__name__) @@ -29,10 +30,9 @@ _OUTPUTS_DIR = _PROJECT_ROOT / "outputs" # --------------------------------------------------------------------------- -# Singleton model cache +# Per-analysis-type model cache # --------------------------------------------------------------------------- -_cached_model: Optional[UNet] = None -_cached_device: Optional[torch.device] = None +_model_cache: dict[str, tuple[UNet, torch.device]] = {} def _get_device() -> torch.device: @@ -41,11 +41,18 @@ def _get_device() -> torch.device: return torch.device("cpu") -def _find_best_checkpoint() -> Optional[Path]: +def _find_best_checkpoint(analysis_type: str) -> Optional[Path]: """ - Search for the best available checkpoint. - Priority: models/best_model.pth > newest models/*/best_model.pth + Search for the best available checkpoint for an analysis type. + Priority: config.yaml weight path > models/best_model.pth > newest models/*/best_model.pth """ + model_cfg = get_model_config(analysis_type) + config_path = model_cfg.get("weights") + if config_path: + p = _PROJECT_ROOT / config_path + if p.exists(): + return p + direct = _MODELS_DIR / "best_model.pth" if direct.exists(): return direct @@ -57,17 +64,21 @@ def _find_best_checkpoint() -> Optional[Path]: return candidates[0] if candidates else None -def _load_model() -> tuple[UNet, torch.device]: - """Load (or return cached) U-Net model.""" - global _cached_model, _cached_device +def _load_model(analysis_type: str = "deforestation") -> tuple[UNet, torch.device]: + """Load (or return cached) U-Net model configured for the analysis type.""" + global _model_cache - if _cached_model is not None and _cached_device is not None: - return _cached_model, _cached_device + if analysis_type in _model_cache: + return _model_cache[analysis_type] device = _get_device() - model = UNet(n_channels=4, n_classes=2) + model_cfg = get_model_config(analysis_type) + n_channels = model_cfg.get("in_channels", 4) + n_classes = model_cfg.get("num_classes", 2) + + model = UNet(n_channels=n_channels, n_classes=n_classes) - model_path = _find_best_checkpoint() + model_path = _find_best_checkpoint(analysis_type) if model_path is not None: checkpoint = torch.load(model_path, map_location=device) @@ -85,21 +96,23 @@ def _load_model() -> tuple[UNet, torch.device]: param.data.copy_(ema_state[name]) logger.info( - "Loaded model from %s (epoch %s val_iou %.4f)", + "Loaded %s model from %s (epoch %s val_iou %.4f)", + analysis_type, model_path, checkpoint.get("epoch", "?"), checkpoint.get("val_iou", 0.0), ) else: logger.warning( - "No trained model found under %s — using untrained weights (demo).", _MODELS_DIR + "No trained model found for %s under %s — using untrained weights (demo).", + analysis_type, + _MODELS_DIR, ) model = model.to(device) model.eval() - _cached_model = model - _cached_device = device + _model_cache[analysis_type] = (model, device) return model, device @@ -193,6 +206,7 @@ def run_inference( bbox: Optional[list[float]] = None, start_date: Optional[str] = None, end_date: Optional[str] = None, + analysis_type: str = "deforestation", ) -> dict[str, Any]: """ Run full inference pipeline on a (C, H, W) numpy image. @@ -205,34 +219,54 @@ def run_inference( ndvi_stats = _compute_ndvi_stats(image) - # Prepare tensor — model expects (N, 4, H, W) + model, device = _load_model(analysis_type) + n_channels = model.n_channels + n_classes = model.n_classes + + # Prepare tensor — model expects (N, n_channels, H, W) c, h, w = image.shape - if c < 4: + if c < n_channels: # Pad missing channels with zeros - pad = np.zeros((4 - c, h, w), dtype=image.dtype) + pad = np.zeros((n_channels - c, h, w), dtype=image.dtype) image = np.concatenate([image, pad], axis=0) - elif c > 4: - image = image[:4] + elif c > n_channels: + image = image[:n_channels] # Use torch.FloatTensor via tolist() to avoid numpy<->torch interop issues - tensor = torch.FloatTensor(image.astype(np.float32).tolist()).unsqueeze(0) # (1, 4, H, W) - - model, device = _load_model() + tensor = torch.FloatTensor(image.astype(np.float32).tolist()).unsqueeze(0) # (1, C, H, W) tensor = tensor.to(device) with torch.no_grad(): output = model(tensor) predictions = torch.argmax(output, dim=1) # (1, H, W) - probabilities = torch.softmax(output, dim=1) # (1, 2, H, W) + probabilities = torch.softmax(output, dim=1) # (1, n_classes, H, W) - forest_pixels = int((predictions == 1).sum().item()) total_pixels = int(predictions.numel()) - non_forest_pixels = total_pixels - forest_pixels - forest_percentage = (forest_pixels / total_pixels) * 100 if total_pixels else 0.0 - max_probs = probabilities.max(dim=1).values mean_confidence = float(max_probs.mean().item()) + # Build per-class pixel counts + class_pixels: dict[str, int] = {} + class_percentages: dict[str, float] = {} + for cls in range(n_classes): + count = int((predictions == cls).sum().item()) + pct = (count / total_pixels) * 100 if total_pixels else 0.0 + class_pixels[f"class_{cls}_pixels"] = count + class_percentages[f"class_{cls}_percentage"] = round(pct, 4) + + # Add friendly keys for known 2-class deforestation output (backward compat) + inference: dict[str, Any] = { + "image_size": [h, w], + "num_classes": n_classes, + "mean_confidence": round(mean_confidence, 4), + **class_pixels, + **class_percentages, + } + if n_classes == 2: + inference["forest_pixels"] = class_pixels.get("class_1_pixels", 0) + inference["non_forest_pixels"] = class_pixels.get("class_0_pixels", 0) + inference["forest_percentage"] = class_percentages.get("class_1_percentage", 0.0) + region: dict[str, Any] = {} if bbox is not None: region["bbox"] = bbox @@ -242,13 +276,7 @@ def run_inference( return { "region": region, "ndvi_stats": ndvi_stats, - "inference": { - "image_size": [h, w], - "forest_pixels": forest_pixels, - "non_forest_pixels": non_forest_pixels, - "forest_percentage": round(forest_percentage, 4), - "mean_confidence": round(mean_confidence, 4), - }, + "inference": inference, } @@ -262,12 +290,19 @@ def run_inference_from_file( bbox: Optional[list[float]] = None, start_date: Optional[str] = None, end_date: Optional[str] = None, + analysis_type: str = "deforestation", ) -> dict[str, Any]: """ Load an image file (GeoTIFF or PNG/JPEG) and run inference. """ image = _load_image_file(path) - result = run_inference(image, bbox=bbox, start_date=start_date, end_date=end_date) + result = run_inference( + image, + bbox=bbox, + start_date=start_date, + end_date=end_date, + analysis_type=analysis_type, + ) result.setdefault("input", {})["file"] = path return result @@ -314,15 +349,13 @@ def run_inference_from_gee( bbox: Optional[list[float]] = None, start_date: Optional[str] = None, end_date: Optional[str] = None, + analysis_type: str = "deforestation", ) -> dict[str, Any]: """ - Query Google Earth Engine for NDVI stats and run model on synthetic data. - - GEE provides real NDVI statistics computed server-side. - Model inference uses a synthetic image (same as run_training.py) because - downloading actual GEE pixel data requires additional infrastructure. + Query Google Earth Engine for a real Sentinel-2 tile and run inference. - Falls back to outputs/inference_results.json or zeros if GEE unavailable. + Falls back to synthetic NDVI stats and a synthetic tile if GEE is + unavailable or returns no images. """ ndvi_stats: Optional[dict[str, Any]] = None gee_count: int = 0 @@ -330,51 +363,69 @@ def run_inference_from_gee( if bbox and start_date and end_date: ndvi_stats, gee_count = _try_gee_ndvi(bbox, start_date, end_date) - # --- Model inference on synthetic image (matches run_training.py) --- - model, device = _load_model() - test_image = torch.randn(1, 4, 256, 256).to(device) + # --- Attempt to download a real tile from GEE --- + try: + from climatevision.data import download_tile_for_analysis, apply_scl_cloud_mask - with torch.no_grad(): - output = model(test_image) - predictions = torch.argmax(output, dim=1) - probabilities = torch.softmax(output, dim=1) + tile_path, metadata = download_tile_for_analysis( + bbox=bbox, + start_date=start_date, + end_date=end_date, + analysis_type=analysis_type, + ) - forest_pixels = int((predictions == 1).sum().item()) - total_pixels = int(predictions.numel()) - non_forest_pixels = total_pixels - forest_pixels - forest_percentage = (forest_pixels / total_pixels) * 100 if total_pixels else 0.0 - max_probs = probabilities.max(dim=1).values - mean_confidence = float(max_probs.mean().item()) + image = _load_image_file(str(tile_path)) + + # If SCL band is present (last band), apply cloud mask and drop it + n_bands_expected = len(get_bands_for_analysis(analysis_type)) + if image.shape[0] == n_bands_expected + 1: + scl_band = image[-1].astype(np.uint8) + image = image[:-1] + image = apply_scl_cloud_mask(image, scl_band) + + result = run_inference( + image, + bbox=bbox, + start_date=start_date, + end_date=end_date, + analysis_type=analysis_type, + ) + result["metadata"] = metadata + + # Override NDVI with GEE-derived stats if we got them; else keep computed + if ndvi_stats is not None: + result["ndvi_stats"] = ndvi_stats + elif metadata.get("is_synthetic"): + result["ndvi_stats"] = _synthetic_ndvi_stats(bbox) + + if gee_count: + result["region"]["images_available"] = gee_count + + return result + + except Exception as exc: + logger.warning("Real tile inference failed (%s). Using fallback.", exc) + + # --- Fallback: template result with synthetic stats --- + result = run_inference( + np.zeros((4, 256, 256), dtype=np.float32), + bbox=bbox, + start_date=start_date, + end_date=end_date, + analysis_type=analysis_type, + ) - # Fall back to synthetic realistic NDVI when GEE is unavailable if ndvi_stats is None: - cached = _load_cached_ndvi() - # _load_cached_ndvi returns zeros when no cache exists — use synthetic instead - if all(v == 0.0 for v in cached.values()): - ndvi_stats = _synthetic_ndvi_stats(bbox) - logger.info("GEE unavailable — using synthetic NDVI stats for bbox %s", bbox) - else: - ndvi_stats = cached + ndvi_stats = _synthetic_ndvi_stats(bbox) + result["ndvi_stats"] = ndvi_stats - region: dict[str, Any] = {} - if bbox is not None: - region["bbox"] = bbox - if start_date and end_date: - region["date_range"] = f"{start_date} to {end_date}" + region = result.get("region", {}) if gee_count: region["images_available"] = gee_count + result["region"] = region + result["metadata"] = {"is_synthetic": True, "fallback_reason": "gee_tile_download_failed"} - return { - "region": region, - "ndvi_stats": ndvi_stats, - "inference": { - "image_size": [256, 256], - "forest_pixels": forest_pixels, - "non_forest_pixels": non_forest_pixels, - "forest_percentage": round(forest_percentage, 4), - "mean_confidence": round(mean_confidence, 4), - }, - } + return result def _try_gee_ndvi( From 0508417d4b6ae422b5c1fa69c8bee85df5c46e54 Mon Sep 17 00:00:00 2001 From: Olufemi Taiwo Date: Wed, 15 Apr 2026 20:49:51 +0100 Subject: [PATCH 2/2] feat(api): wire analysis_type into prediction endpoints - Pass body.analysis_type to run_inference_from_gee() in /api/predict - Pass analysis_type to run_inference_from_file() in /api/predict/upload - Enables the API to load the correct model and return correct class counts per analysis type --- src/climatevision/api/main.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/climatevision/api/main.py b/src/climatevision/api/main.py index a155ed4..ac40911 100644 --- a/src/climatevision/api/main.py +++ b/src/climatevision/api/main.py @@ -552,6 +552,7 @@ async def predict_json(body: PredictRequest) -> dict[str, Any]: bbox=body.bbox, start_date=body.start_date, end_date=body.end_date, + analysis_type=body.analysis_type, ) result_payload["analysis_type"] = body.analysis_type status = "completed" @@ -633,6 +634,7 @@ async def predict_upload( bbox=parsed_bbox, start_date=start_date, end_date=end_date, + analysis_type=analysis_type, ) result_payload["analysis_type"] = analysis_type status = "completed"