Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/climatevision/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,7 @@ async def predict_json(body: PredictRequest) -> dict[str, Any]:
bbox=body.bbox,
start_date=body.start_date,
end_date=body.end_date,
analysis_type=body.analysis_type,
)
result_payload["analysis_type"] = body.analysis_type
status = "completed"
Expand Down Expand Up @@ -633,6 +634,7 @@ async def predict_upload(
bbox=parsed_bbox,
start_date=start_date,
end_date=end_date,
analysis_type=analysis_type,
)
result_payload["analysis_type"] = analysis_type
status = "completed"
Expand Down
219 changes: 135 additions & 84 deletions src/climatevision/inference/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
Inference pipeline for ClimateVision.

Provides:
- run_inference(image_array, bbox, start_date, end_date) — core inference on a numpy array
- run_inference_from_file(path, bbox, start_date, end_date) — load file then infer
- run_inference_from_gee(bbox, start_date, end_date) — GEE NDVI + synthetic model inference
- run_inference(image_array, bbox, start_date, end_date, analysis_type) — core inference on a numpy array
- run_inference_from_file(path, bbox, start_date, end_date, analysis_type) — load file then infer
- run_inference_from_gee(bbox, start_date, end_date, analysis_type) — GEE NDVI + real tile inference
"""

from __future__ import annotations
Expand All @@ -17,6 +17,7 @@
import numpy as np
import torch

from climatevision.data.band_mapping import get_bands_for_analysis, get_model_config
from climatevision.models.unet import UNet

logger = logging.getLogger(__name__)
Expand All @@ -29,10 +30,9 @@
_OUTPUTS_DIR = _PROJECT_ROOT / "outputs"

# ---------------------------------------------------------------------------
# Singleton model cache
# Per-analysis-type model cache
# ---------------------------------------------------------------------------
_cached_model: Optional[UNet] = None
_cached_device: Optional[torch.device] = None
_model_cache: dict[str, tuple[UNet, torch.device]] = {}


def _get_device() -> torch.device:
Expand All @@ -41,11 +41,18 @@ def _get_device() -> torch.device:
return torch.device("cpu")


def _find_best_checkpoint() -> Optional[Path]:
def _find_best_checkpoint(analysis_type: str) -> Optional[Path]:
"""
Search for the best available checkpoint.
Priority: models/best_model.pth > newest models/*/best_model.pth
Search for the best available checkpoint for an analysis type.
Priority: config.yaml weight path > models/best_model.pth > newest models/*/best_model.pth
"""
model_cfg = get_model_config(analysis_type)
config_path = model_cfg.get("weights")
if config_path:
p = _PROJECT_ROOT / config_path
if p.exists():
return p

direct = _MODELS_DIR / "best_model.pth"
if direct.exists():
return direct
Expand All @@ -57,17 +64,21 @@ def _find_best_checkpoint() -> Optional[Path]:
return candidates[0] if candidates else None


def _load_model() -> tuple[UNet, torch.device]:
"""Load (or return cached) U-Net model."""
global _cached_model, _cached_device
def _load_model(analysis_type: str = "deforestation") -> tuple[UNet, torch.device]:
"""Load (or return cached) U-Net model configured for the analysis type."""
global _model_cache

if _cached_model is not None and _cached_device is not None:
return _cached_model, _cached_device
if analysis_type in _model_cache:
return _model_cache[analysis_type]

device = _get_device()
model = UNet(n_channels=4, n_classes=2)
model_cfg = get_model_config(analysis_type)
n_channels = model_cfg.get("in_channels", 4)
n_classes = model_cfg.get("num_classes", 2)

model = UNet(n_channels=n_channels, n_classes=n_classes)

model_path = _find_best_checkpoint()
model_path = _find_best_checkpoint(analysis_type)
if model_path is not None:
checkpoint = torch.load(model_path, map_location=device)

Expand All @@ -85,21 +96,23 @@ def _load_model() -> tuple[UNet, torch.device]:
param.data.copy_(ema_state[name])

logger.info(
"Loaded model from %s (epoch %s val_iou %.4f)",
"Loaded %s model from %s (epoch %s val_iou %.4f)",
analysis_type,
model_path,
checkpoint.get("epoch", "?"),
checkpoint.get("val_iou", 0.0),
)
else:
logger.warning(
"No trained model found under %s — using untrained weights (demo).", _MODELS_DIR
"No trained model found for %s under %s — using untrained weights (demo).",
analysis_type,
_MODELS_DIR,
)

model = model.to(device)
model.eval()

_cached_model = model
_cached_device = device
_model_cache[analysis_type] = (model, device)
return model, device


Expand Down Expand Up @@ -193,6 +206,7 @@ def run_inference(
bbox: Optional[list[float]] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
analysis_type: str = "deforestation",
) -> dict[str, Any]:
"""
Run full inference pipeline on a (C, H, W) numpy image.
Expand All @@ -205,34 +219,54 @@ def run_inference(

ndvi_stats = _compute_ndvi_stats(image)

# Prepare tensor — model expects (N, 4, H, W)
model, device = _load_model(analysis_type)
n_channels = model.n_channels
n_classes = model.n_classes

# Prepare tensor — model expects (N, n_channels, H, W)
c, h, w = image.shape
if c < 4:
if c < n_channels:
# Pad missing channels with zeros
pad = np.zeros((4 - c, h, w), dtype=image.dtype)
pad = np.zeros((n_channels - c, h, w), dtype=image.dtype)
image = np.concatenate([image, pad], axis=0)
elif c > 4:
image = image[:4]
elif c > n_channels:
image = image[:n_channels]

# Use torch.FloatTensor via tolist() to avoid numpy<->torch interop issues
tensor = torch.FloatTensor(image.astype(np.float32).tolist()).unsqueeze(0) # (1, 4, H, W)

model, device = _load_model()
tensor = torch.FloatTensor(image.astype(np.float32).tolist()).unsqueeze(0) # (1, C, H, W)
tensor = tensor.to(device)

with torch.no_grad():
output = model(tensor)
predictions = torch.argmax(output, dim=1) # (1, H, W)
probabilities = torch.softmax(output, dim=1) # (1, 2, H, W)
probabilities = torch.softmax(output, dim=1) # (1, n_classes, H, W)

forest_pixels = int((predictions == 1).sum().item())
total_pixels = int(predictions.numel())
non_forest_pixels = total_pixels - forest_pixels
forest_percentage = (forest_pixels / total_pixels) * 100 if total_pixels else 0.0

max_probs = probabilities.max(dim=1).values
mean_confidence = float(max_probs.mean().item())

# Build per-class pixel counts
class_pixels: dict[str, int] = {}
class_percentages: dict[str, float] = {}
for cls in range(n_classes):
count = int((predictions == cls).sum().item())
pct = (count / total_pixels) * 100 if total_pixels else 0.0
class_pixels[f"class_{cls}_pixels"] = count
class_percentages[f"class_{cls}_percentage"] = round(pct, 4)

# Add friendly keys for known 2-class deforestation output (backward compat)
inference: dict[str, Any] = {
"image_size": [h, w],
"num_classes": n_classes,
"mean_confidence": round(mean_confidence, 4),
**class_pixels,
**class_percentages,
}
if n_classes == 2:
inference["forest_pixels"] = class_pixels.get("class_1_pixels", 0)
inference["non_forest_pixels"] = class_pixels.get("class_0_pixels", 0)
inference["forest_percentage"] = class_percentages.get("class_1_percentage", 0.0)

region: dict[str, Any] = {}
if bbox is not None:
region["bbox"] = bbox
Expand All @@ -242,13 +276,7 @@ def run_inference(
return {
"region": region,
"ndvi_stats": ndvi_stats,
"inference": {
"image_size": [h, w],
"forest_pixels": forest_pixels,
"non_forest_pixels": non_forest_pixels,
"forest_percentage": round(forest_percentage, 4),
"mean_confidence": round(mean_confidence, 4),
},
"inference": inference,
}


Expand All @@ -262,12 +290,19 @@ def run_inference_from_file(
bbox: Optional[list[float]] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
analysis_type: str = "deforestation",
) -> dict[str, Any]:
"""
Load an image file (GeoTIFF or PNG/JPEG) and run inference.
"""
image = _load_image_file(path)
result = run_inference(image, bbox=bbox, start_date=start_date, end_date=end_date)
result = run_inference(
image,
bbox=bbox,
start_date=start_date,
end_date=end_date,
analysis_type=analysis_type,
)
result.setdefault("input", {})["file"] = path
return result

Expand Down Expand Up @@ -314,67 +349,83 @@ def run_inference_from_gee(
bbox: Optional[list[float]] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
analysis_type: str = "deforestation",
) -> dict[str, Any]:
"""
Query Google Earth Engine for NDVI stats and run model on synthetic data.

GEE provides real NDVI statistics computed server-side.
Model inference uses a synthetic image (same as run_training.py) because
downloading actual GEE pixel data requires additional infrastructure.
Query Google Earth Engine for a real Sentinel-2 tile and run inference.

Falls back to outputs/inference_results.json or zeros if GEE unavailable.
Falls back to synthetic NDVI stats and a synthetic tile if GEE is
unavailable or returns no images.
"""
ndvi_stats: Optional[dict[str, Any]] = None
gee_count: int = 0

if bbox and start_date and end_date:
ndvi_stats, gee_count = _try_gee_ndvi(bbox, start_date, end_date)

# --- Model inference on synthetic image (matches run_training.py) ---
model, device = _load_model()
test_image = torch.randn(1, 4, 256, 256).to(device)
# --- Attempt to download a real tile from GEE ---
try:
from climatevision.data import download_tile_for_analysis, apply_scl_cloud_mask

with torch.no_grad():
output = model(test_image)
predictions = torch.argmax(output, dim=1)
probabilities = torch.softmax(output, dim=1)
tile_path, metadata = download_tile_for_analysis(
bbox=bbox,
start_date=start_date,
end_date=end_date,
analysis_type=analysis_type,
)

forest_pixels = int((predictions == 1).sum().item())
total_pixels = int(predictions.numel())
non_forest_pixels = total_pixels - forest_pixels
forest_percentage = (forest_pixels / total_pixels) * 100 if total_pixels else 0.0
max_probs = probabilities.max(dim=1).values
mean_confidence = float(max_probs.mean().item())
image = _load_image_file(str(tile_path))

# If SCL band is present (last band), apply cloud mask and drop it
n_bands_expected = len(get_bands_for_analysis(analysis_type))
if image.shape[0] == n_bands_expected + 1:
scl_band = image[-1].astype(np.uint8)
image = image[:-1]
image = apply_scl_cloud_mask(image, scl_band)

result = run_inference(
image,
bbox=bbox,
start_date=start_date,
end_date=end_date,
analysis_type=analysis_type,
)
result["metadata"] = metadata

# Override NDVI with GEE-derived stats if we got them; else keep computed
if ndvi_stats is not None:
result["ndvi_stats"] = ndvi_stats
elif metadata.get("is_synthetic"):
result["ndvi_stats"] = _synthetic_ndvi_stats(bbox)

if gee_count:
result["region"]["images_available"] = gee_count

return result

except Exception as exc:
logger.warning("Real tile inference failed (%s). Using fallback.", exc)

# --- Fallback: template result with synthetic stats ---
result = run_inference(
np.zeros((4, 256, 256), dtype=np.float32),
bbox=bbox,
start_date=start_date,
end_date=end_date,
analysis_type=analysis_type,
)

# Fall back to synthetic realistic NDVI when GEE is unavailable
if ndvi_stats is None:
cached = _load_cached_ndvi()
# _load_cached_ndvi returns zeros when no cache exists — use synthetic instead
if all(v == 0.0 for v in cached.values()):
ndvi_stats = _synthetic_ndvi_stats(bbox)
logger.info("GEE unavailable — using synthetic NDVI stats for bbox %s", bbox)
else:
ndvi_stats = cached
ndvi_stats = _synthetic_ndvi_stats(bbox)
result["ndvi_stats"] = ndvi_stats

region: dict[str, Any] = {}
if bbox is not None:
region["bbox"] = bbox
if start_date and end_date:
region["date_range"] = f"{start_date} to {end_date}"
region = result.get("region", {})
if gee_count:
region["images_available"] = gee_count
result["region"] = region
result["metadata"] = {"is_synthetic": True, "fallback_reason": "gee_tile_download_failed"}

return {
"region": region,
"ndvi_stats": ndvi_stats,
"inference": {
"image_size": [256, 256],
"forest_pixels": forest_pixels,
"non_forest_pixels": non_forest_pixels,
"forest_percentage": round(forest_percentage, 4),
"mean_confidence": round(mean_confidence, 4),
},
}
return result


def _try_gee_ndvi(
Expand Down