## 1. Setup
Install the updated project requirements (including scipy, pyproj, matplotlib and folium) if you haven't already.

In [None]:
from __future__ import annotations

import json
from dataclasses import dataclass
from pathlib import Path
from typing import Tuple

import folium
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from dotenv import load_dotenv
from matplotlib import cm, colors
from pyproj import Transformer
from scipy.spatial import cKDTree
import sqlalchemy as sa

# Read environment settings from the .env file one directory above the working directory.
load_dotenv(Path('..') / '.env')

# All project paths are resolved relative to the parent of the working directory.
BASE_DIR = Path('..').resolve()
PROCESSED_DIR = BASE_DIR.joinpath('data', 'processed')
PROCESSED_DIR.mkdir(exist_ok=True, parents=True)  # no error if it already exists

print('Base dir:', BASE_DIR)


## 2. Fetch a snapshot from `clean_measurements`
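Choose a UTC timestamp to interpolate onto the grid; by default the most recent one is used. (Later cells in this section switch to the timestamp with the highest median precipitation.)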

In [None]:
from datetime import timedelta

def ensure_utc(value) -> pd.Timestamp:
    """Coerce *value* to a timezone-aware pandas Timestamp in UTC.

    Naive inputs are interpreted as already being UTC wall-clock time;
    timezone-aware inputs are converted to UTC.
    """
    stamp = pd.Timestamp(value)
    return stamp.tz_localize('UTC') if stamp.tzinfo is None else stamp.tz_convert('UTC')


def load_clean_snapshot(target_ts=None, window_minutes=5):
    """Load clean measurements joined with sensor coordinates around one timestamp.

    Parameters
    ----------
    target_ts : timestamp-like, optional
        Centre of the window. Naive values are treated as UTC. When None, the
        most recent ``ts`` in ``clean_measurements`` is used.
    window_minutes : int or float
        Half-width of the window: rows with ``ts`` in
        ``[target - window, target + window]`` (inclusive, SQL BETWEEN) are returned.

    Returns
    -------
    pandas.DataFrame
        Columns sensor_id, ts (UTC-aware), value_mm, imputation_method, lat, lon,
        sorted by sensor_id with a fresh 0..n-1 index. A sensor can appear more
        than once if it has several readings inside the window.

    Raises
    ------
    ValueError
        If ``target_ts`` is None and ``clean_measurements`` has no rows.
    """
    if target_ts is None:
        # Resolve "most recent" from the table itself instead of depending on a
        # notebook global that may not exist.
        latest = pd.read_sql(
            sa.text('SELECT max(ts) AS latest_ts FROM clean_measurements'),
            engine,
        )
        target_ts = latest.at[0, 'latest_ts']
        if pd.isna(target_ts):
            raise ValueError('clean_measurements is empty; cannot pick a default timestamp')

    ts = ensure_utc(target_ts)
    half_window = pd.Timedelta(minutes=window_minutes)
    start, end = ts - half_window, ts + half_window
    query = sa.text(
        'SELECT cm.sensor_id, cm.ts, cm.value_mm, cm.imputation_method, '
        's.lat, s.lon '
        'FROM clean_measurements cm '
        'JOIN sensors s ON s.id = cm.sensor_id '
        'WHERE cm.ts BETWEEN :start AND :end'
    )
    df = pd.read_sql(query, engine, params={'start': start, 'end': end})
    df['ts'] = pd.to_datetime(df['ts'], utc=True)
    # Reset the index so label-based and positional access agree for callers.
    return df.sort_values('sensor_id').reset_index(drop=True)

# Load the snapshot around the loader's default timestamp and preview it.
snapshot_df = load_clean_snapshot(target_ts=None)
print(snapshot_df.shape)  # (rows, columns)
snapshot_df.head()


In [None]:
from datetime import timedelta

def ensure_utc(value) -> pd.Timestamp:
    """Coerce *value* to a timezone-aware pandas Timestamp in UTC.

    Naive inputs are interpreted as already being UTC wall-clock time;
    timezone-aware inputs are converted to UTC.
    """
    stamp = pd.Timestamp(value)
    return stamp.tz_localize('UTC') if stamp.tzinfo is None else stamp.tz_convert('UTC')


def load_clean_snapshot(target_ts=None, window_minutes=5):
    """Load clean measurements joined with sensor coordinates around one timestamp.

    Parameters
    ----------
    target_ts : timestamp-like, optional
        Centre of the window. Naive values are treated as UTC. When None, the
        most recent ``ts`` in ``clean_measurements`` is used.
    window_minutes : int or float
        Half-width of the window: rows with ``ts`` in
        ``[target - window, target + window]`` (inclusive, SQL BETWEEN) are returned.

    Returns
    -------
    pandas.DataFrame
        Columns sensor_id, ts (UTC-aware), value_mm, imputation_method, lat, lon,
        sorted by sensor_id with a fresh 0..n-1 index. A sensor can appear more
        than once if it has several readings inside the window.

    Raises
    ------
    ValueError
        If ``target_ts`` is None and ``clean_measurements`` has no rows.
    """
    if target_ts is None:
        # Resolve "most recent" from the table itself instead of depending on a
        # notebook global that may not exist.
        latest = pd.read_sql(
            sa.text('SELECT max(ts) AS latest_ts FROM clean_measurements'),
            engine,
        )
        target_ts = latest.at[0, 'latest_ts']
        if pd.isna(target_ts):
            raise ValueError('clean_measurements is empty; cannot pick a default timestamp')

    ts = ensure_utc(target_ts)
    half_window = pd.Timedelta(minutes=window_minutes)
    start, end = ts - half_window, ts + half_window
    query = sa.text(
        'SELECT cm.sensor_id, cm.ts, cm.value_mm, cm.imputation_method, '
        's.lat, s.lon '
        'FROM clean_measurements cm '
        'JOIN sensors s ON s.id = cm.sensor_id '
        'WHERE cm.ts BETWEEN :start AND :end'
    )
    df = pd.read_sql(query, engine, params={'start': start, 'end': end})
    df['ts'] = pd.to_datetime(df['ts'], utc=True)
    # Reset the index so label-based and positional access agree for callers.
    return df.sort_values('sensor_id').reset_index(drop=True)

# Load the snapshot around the loader's default timestamp and preview it.
snapshot_df = load_clean_snapshot(target_ts=None)
print(snapshot_df.shape)  # (rows, columns)
snapshot_df.head()


# Find the timestamp whose median precipitation across sensors is highest.
# NULL value_mm rows are filtered out first: percentile_cont ignores NULLs, so a
# timestamp with only NULL readings would get a NULL median. PostgreSQL (whose
# syntax this query uses) sorts NULLs first under DESC, so that row would be
# picked and round() below would fail. `ts DESC` makes ties deterministic.
high_median_row = pd.read_sql(
    sa.text('''
        SELECT ts,
               percentile_cont(0.5) WITHIN GROUP (ORDER BY value_mm) AS median_mm
        FROM clean_measurements
        WHERE value_mm IS NOT NULL
        GROUP BY ts
        ORDER BY median_mm DESC, ts DESC
        LIMIT 1
    '''),
    engine,
)
if high_median_row.empty:
    raise ValueError('clean_measurements has no non-null value_mm readings')

median_ts = ensure_utc(high_median_row.at[0, 'ts'])
median_value = high_median_row.at[0, 'median_mm']

print('Timestamp with highest median precipitation (UTC):', median_ts)
print('Median precipitation at that time:', round(median_value, 2), 'mm')

# Pull every reading within +/- 5 minutes of that timestamp.
high_median_snapshot = load_clean_snapshot(target_ts=median_ts, window_minutes=5)
print('Sensors in window:', len(high_median_snapshot))
high_median_snapshot[['sensor_id', 'value_mm']].head()


# Find the timestamp whose median precipitation across sensors is highest.
# NULL value_mm rows are filtered out first: percentile_cont ignores NULLs, so a
# timestamp with only NULL readings would get a NULL median. PostgreSQL (whose
# syntax this query uses) sorts NULLs first under DESC, so that row would be
# picked and round() below would fail. `ts DESC` makes ties deterministic.
high_median_row = pd.read_sql(
    sa.text('''
        SELECT ts,
               percentile_cont(0.5) WITHIN GROUP (ORDER BY value_mm) AS median_mm
        FROM clean_measurements
        WHERE value_mm IS NOT NULL
        GROUP BY ts
        ORDER BY median_mm DESC, ts DESC
        LIMIT 1
    '''),
    engine,
)
if high_median_row.empty:
    raise ValueError('clean_measurements has no non-null value_mm readings')

median_ts = ensure_utc(high_median_row.at[0, 'ts'])
median_value = high_median_row.at[0, 'median_mm']

print('Timestamp with highest median precipitation (UTC):', median_ts)
print('Median precipitation at that time:', round(median_value, 2), 'mm')

# Pull every reading within +/- 5 minutes of that timestamp.
high_median_snapshot = load_clean_snapshot(target_ts=median_ts, window_minutes=5)
print('Sensors in window:', len(high_median_snapshot))
high_median_snapshot[['sensor_id', 'value_mm']].head()


In [None]:
# Find the timestamp whose median precipitation across sensors is highest,
# returned as a UTC wall-clock value.
# NULL value_mm rows are filtered out first: percentile_cont ignores NULLs, so a
# timestamp with only NULL readings would get a NULL median. PostgreSQL (whose
# syntax this query uses) sorts NULLs first under DESC, so that row would be
# picked and round() below would fail. `ts DESC` makes ties deterministic.
high_median_row = pd.read_sql(
    sa.text('''
        SELECT ts AT TIME ZONE 'UTC' AS ts_utc,
               percentile_cont(0.5) WITHIN GROUP (ORDER BY value_mm) AS median_mm
        FROM clean_measurements
        WHERE value_mm IS NOT NULL
        GROUP BY ts
        ORDER BY median_mm DESC, ts DESC
        LIMIT 1
    '''),
    engine,
)
if high_median_row.empty:
    raise ValueError('clean_measurements has no non-null value_mm readings')

# utc=True yields a UTC-aware Timestamp whether the driver returns it naive or aware.
median_ts = pd.to_datetime(high_median_row.at[0, 'ts_utc'], utc=True)
median_value = high_median_row.at[0, 'median_mm']

print('Timestamp with highest median precipitation (UTC):', median_ts)
print('Median precipitation at that time:', round(median_value, 2), 'mm')

# Pull every reading within +/- 5 minutes of that timestamp.
high_median_snapshot = load_clean_snapshot(target_ts=median_ts, window_minutes=5)
print('Sensors in window:', len(high_median_snapshot))
high_median_snapshot[['sensor_id', 'value_mm']].head()


In [None]:
# From here on, grid the high-median snapshot (an independent deep copy).
snapshot_df = high_median_snapshot.copy(deep=True)
print('Using high-median snapshot:', snapshot_df.ts.max())


In [None]:
# Transformers between WGS84 lon/lat and Web Mercator (EPSG:3857, metres).
# always_xy=True means (lon, lat) / (x, y) argument order in both directions.
transformer_to_3857 = Transformer.from_crs('EPSG:4326', 'EPSG:3857', always_xy=True)
transformer_to_wgs84 = Transformer.from_crs('EPSG:3857', 'EPSG:4326', always_xy=True)

if snapshot_df.empty:
    raise ValueError('snapshot_df is empty; nothing to grid')

snapshot_df['x'], snapshot_df['y'] = transformer_to_3857.transform(
    snapshot_df['lon'].to_numpy(),
    snapshot_df['lat'].to_numpy(),
)

bbox_padding = 2000  # metres
res_m = 500  # node spacing in metres

min_x = snapshot_df['x'].min() - bbox_padding
min_y = snapshot_df['y'].min() - bbox_padding

# Enough nodes to reach the padded far edge of the sensor extent.
nx = int(np.ceil((snapshot_df['x'].max() + bbox_padding - min_x) / res_m)) + 1
ny = int(np.ceil((snapshot_df['y'].max() + bbox_padding - min_y) / res_m)) + 1

# Lay the axes out with an exact res_m step. np.linspace(min, max, n) would shrink
# the spacing below res_m whenever the extent is not a multiple of res_m. Then
# "offset / res_m" cell lookups and the recorded res_m would not match the grid.
x_grid = min_x + res_m * np.arange(nx)
y_grid = min_y + res_m * np.arange(ny)

# Snap the far edges outward to the last node so the bbox describes the grid exactly.
max_x = x_grid[-1]
max_y = y_grid[-1]

xx, yy = np.meshgrid(x_grid, y_grid)

print('Grid shape:', yy.shape, 'resolution (m):', res_m)


### Seed the grid with nearest sensor values

In [None]:
# Seed the grid: every usable reading is dropped into the grid node nearest to
# its projected (x, y) position; a node hit by several readings gets their mean.
seed_grid = np.full((ny, nx), np.nan, dtype=float)

# Work with positional arrays so element i of each array is the same row.
# snapshot_df comes from a sort_values() result, so its index labels need not
# match row positions, and label lookups like Series[i] can pick another sensor.
sensor_values = snapshot_df['value_mm'].to_numpy(dtype=float)
sensor_x = snapshot_df['x'].to_numpy(dtype=float)
sensor_y = snapshot_df['y'].to_numpy(dtype=float)
usable = np.isfinite(sensor_values) & np.isfinite(sensor_x) & np.isfinite(sensor_y)

# Use the actual node spacing of the axes rather than assuming res_m.
step_x = (x_grid[-1] - x_grid[0]) / (nx - 1) if nx > 1 else float(res_m)
step_y = (y_grid[-1] - y_grid[0]) / (ny - 1) if ny > 1 else float(res_m)

xi = np.clip(np.rint((sensor_x[usable] - x_grid[0]) / step_x), 0, nx - 1).astype(int)
yi = np.clip(np.rint((sensor_y[usable] - y_grid[0]) / step_y), 0, ny - 1).astype(int)

# np.add.at accumulates correctly when the same (yi, xi) pair repeats, giving a
# true per-cell mean (a running pairwise nanmean over-weights later readings).
value_sum = np.zeros((ny, nx), dtype=float)
value_count = np.zeros((ny, nx), dtype=int)
np.add.at(value_sum, (yi, xi), sensor_values[usable])
np.add.at(value_count, (yi, xi), 1)

mask = value_count > 0
seed_grid[mask] = value_sum[mask] / value_count[mask]

np.count_nonzero(mask), seed_grid.shape


## 4. Lanczos-4 interpolation

In [None]:
def lanczos_kernel(radius: int, a: int = 4, scale: float = 1.0) -> np.ndarray:
    """Return a normalised, separable 2-D Lanczos kernel.

    The 1-D window ``L(u) = sinc(u) * sinc(u / a)`` (zeroed for ``|u| > a``) is
    sampled at ``u = k / scale`` for integer offsets ``k`` in ``[-radius, radius]``.
    The 2-D kernel is the outer product of that vector with itself, divided by
    its sum so the weights add up to 1.

    Note: with ``scale=1`` every tap falls on an integer, where ``sinc(k) == 0``
    for ``k != 0``, so the kernel is (to floating-point precision) a delta at the
    centre. Use ``scale > 1`` with ``radius >= a * scale`` for a kernel that
    actually spreads weight over neighbouring cells.

    Args:
        radius: Half-width of the kernel in cells; the result is
            ``(2 * radius + 1, 2 * radius + 1)``.
        a: Lanczos order (number of lobes on each side).
        scale: Cells per unit of the Lanczos argument. The default 1.0
            reproduces integer sampling.

    Returns:
        Float array of shape ``(2 * radius + 1, 2 * radius + 1)`` summing to 1.

    Raises:
        ValueError: If ``scale`` is not positive.
    """
    if scale <= 0:
        raise ValueError(f'scale must be positive, got {scale!r}')
    u = np.arange(-radius, radius + 1, dtype=float) / scale
    # np.sinc(0) is defined as 1, so the centre tap needs no special casing.
    k1d = np.sinc(u) * np.sinc(u / a)
    k1d[np.abs(u) > a] = 0.0
    kernel = np.outer(k1d, k1d)
    return kernel / kernel.sum()

from scipy.signal import convolve2d

kernel = lanczos_kernel(radius=4, a=4)
# NOTE(review): with radius=4 and a=4 the taps sit on integer offsets, where
# sinc(k) == 0 for k != 0. The kernel is therefore numerically a delta at its
# centre: this step reproduces the seeded cells and leaves every other cell to the
# nearest-neighbour fill that follows.

# Normalised convolution: blur the seed values (NaN -> 0) and, separately, the
# "has a seed" indicator; their ratio is the kernel-weighted mean of nearby seeds.
seed_values = np.nan_to_num(seed_grid, nan=0.0)
seed_mask = (~np.isnan(seed_grid)).astype(float)

num = convolve2d(seed_values, kernel, mode='same', boundary='symm')
den = convolve2d(seed_mask, kernel, mode='same', boundary='symm')

# Only trust cells whose accumulated seed weight is a meaningful fraction of the
# kernel peak. A bare `den > 0` also accepts cells that received only
# floating-point residue (~1e-17 off-centre here) or near-cancelling signed
# Lanczos lobes. There num / den is noise and can blow up.
min_weight = 1e-6 * kernel.max()
with np.errstate(divide='ignore', invalid='ignore'):
    lanczos_grid = np.where(den > min_weight, num / den, np.nan)
np.sum(~np.isnan(lanczos_grid))


### 4.1 Fill remaining gaps via nearest neighbour

In [None]:
# Every cell the convolution could not estimate takes the value of the nearest
# estimated cell, measured in grid-index space (x and y nominally share the
# res_m spacing).
remaining = np.isnan(lanczos_grid)
if remaining.all():
    # Nothing to propagate from; fail clearly instead of inside cKDTree/indexing.
    raise ValueError('lanczos_grid has no finite cells to fill gaps from')
if remaining.any():
    valid_idx = np.argwhere(~remaining)
    valid_vals = lanczos_grid[~remaining]  # same C order as valid_idx
    target_idx = np.argwhere(remaining)
    _, nn = cKDTree(valid_idx).query(target_idx)
    lanczos_grid[remaining] = valid_vals[nn]

np.sum(np.isnan(lanczos_grid))


## 5. Visualise grid and contours

In [None]:
# Colour/contour range spans the finite values of the interpolated grid.
vmin, vmax = np.nanmin(lanczos_grid), np.nanmax(lanczos_grid)
levels = np.linspace(vmin, vmax, 12)

fig, ax = plt.subplots(figsize=(10, 6))
mesh = ax.pcolormesh(x_grid, y_grid, lanczos_grid, cmap='viridis', shading='auto')
# For a constant field (e.g. every sensor reporting 0 mm) linspace yields
# non-increasing levels, which contour() rejects, so skip the isolines then.
if vmax > vmin:
    contours = ax.contour(x_grid, y_grid, lanczos_grid, levels=levels, colors='white', linewidths=0.7)
    ax.clabel(contours, inline=True, fontsize=8, fmt='%.1f')
ax.scatter(snapshot_df['x'], snapshot_df['y'], c='red', s=10, label='Sensors')
ax.set_title('Lanczos Interpolated Grid (EPSG:3857)')
ax.set_xlabel('x (m)')
ax.set_ylabel('y (m)')
ax.legend(loc='upper right')
fig.colorbar(mesh, ax=ax, label='Precipitation (mm)')
plt.tight_layout()


## 6. Prepare artifacts (npz + metadata)

In [None]:
# Web Mercator is cylindrical (longitude depends only on x, latitude only on y),
# so the NW and SE corners of the 3857 bbox give the WGS84 extent.
west, north = transformer_to_wgs84.transform(min_x, max_y)
east, south = transformer_to_wgs84.transform(max_x, min_y)

grid_metadata = {
    'timestamp': snapshot_df['ts'].max().isoformat(),
    'res_m': res_m,
    'bbox_3857': [float(edge) for edge in (min_x, min_y, max_x, max_y)],
    'bbox_wgs84': [west, south, east, north],
}

# File name carries the snapshot's latest reading time in compact UTC form.
grid_npz_path = PROCESSED_DIR / (
    'grid_' + snapshot_df['ts'].max().strftime('%Y%m%dT%H%M%SZ') + '.npz'
)
np.savez_compressed(
    grid_npz_path,
    data=lanczos_grid.astype(np.float32),
    x=x_grid.astype(np.float64),
    y=y_grid.astype(np.float64),
    metadata=json.dumps(grid_metadata),
)

print('Saved grid to', grid_npz_path)
print('Metadata:', grid_metadata)


## 7. Interactive map preview (Folium)

In [None]:
# Colour the grid with viridis. pyplot.get_cmap is used because
# matplotlib.cm.get_cmap was deprecated in 3.7 and removed in 3.9.
norm = colors.Normalize(vmin=np.nanmin(lanczos_grid), vmax=np.nanmax(lanczos_grid))
rgba = plt.get_cmap('viridis')(norm(lanczos_grid))
# Fully transparent where the grid has no value, otherwise 75% opaque.
rgba[..., 3] = np.where(np.isnan(lanczos_grid), 0.0, 0.75)

west, north = transformer_to_wgs84.transform(min_x, max_y)
east, south = transformer_to_wgs84.transform(max_x, min_y)
bounds = [[south, west], [north, east]]

center_lat = snapshot_df['lat'].mean()
center_lon = snapshot_df['lon'].mean()

m = folium.Map(location=[center_lat, center_lon], zoom_start=11, tiles='CartoDB positron')
# Row 0 of lanczos_grid corresponds to y_grid[0], the southern (min_y) edge, since
# rows are indexed by increasing y. origin='lower' puts row 0 at the bottom;
# 'upper' would mirror the overlay north-south.
folium.raster_layers.ImageOverlay(
    image=rgba,
    bounds=bounds,
    opacity=0.7,
    origin='lower',
    name='Interpolated Grid',
).add_to(m)

for _, row in snapshot_df.iterrows():
    folium.CircleMarker(
        location=[row['lat'], row['lon']],
        radius=3,
        color='red',
        fill=True,
        fill_opacity=0.9,
        popup=f"{row['sensor_id']}<br>value: {row['value_mm']:.2f} mm",
    ).add_to(m)

folium.LayerControl().add_to(m)

m
