# Feature Extraction Notebook

**Goal:** For each tree in the filtered tree cadastre (edge_15m), extract the following features:
- CHM features (4): height_m (snap-to-peak, taken from the cadastre), CHM_mean, CHM_max, CHM_std
- S2 features (180): (10 bands + 5 indices) × 12 months

**Input:**
- Tree cadastre: `data/tree_cadastres/corrected/processed/trees_corrected_{city}.gpkg`
- Sentinel-2: `data/sentinel2_2021/S2_{city}_2021_{month:02d}_median.tif` (Jan-Dec)
- CHM 10m: `data/CHM/processed/CHM_10m/CHM_10m_{variant}_{city}.tif` (mean, max, std)

**Output:**
- `data/features/trees_with_features_{city}.gpkg` (Berlin, Hamburg, Rostock)

**Feature structure:**
- Tree attributes: tree_id, city, genus_latin, species_latin, geometry
- CHM features (4): height_m, CHM_mean, CHM_max, CHM_std
- S2 features (180): B02_01 ... RTVIcore_12
- Total: 184 features + metadata
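
The S2 column names follow the pattern `{band}_{month:02d}`. A minimal sketch of how the full list can be generated and counted, assuming the band list from the Configuration section; the helper name `build_feature_names` is illustrative and not part of the notebook:

```python
S2_BANDS = [
    "B02", "B03", "B04", "B05", "B06", "B07", "B08", "B8A", "B11", "B12",
    "NDre", "NDVIre", "kNDVI", "VARI", "RTVIcore",
]
MONTHS = range(1, 13)
CHM_FEATURES = ["height_m", "CHM_mean", "CHM_max", "CHM_std"]


def build_feature_names(bands=S2_BANDS, months=MONTHS, chm_features=CHM_FEATURES):
    """Return the CHM feature names followed by band-major S2 names (hypothetical helper)."""
    s2_names = [f"{band}_{month:02d}" for band in bands for month in months]
    return list(chm_features) + s2_names


feature_names = build_feature_names()
print(len(feature_names))                     # 4 CHM + 15 * 12 S2 columns
print(feature_names[4], feature_names[-1])    # first and last S2 column
```

A list like this can double as a schema check on the output files: any missing column indicates a raster with fewer bands than expected.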

**NoData handling:**
- Remove trees with more than 3 months of NoData in S2
- Interpolate trees with 1-3 months of NoData (linear interpolation along the time axis)
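
The rule above can be illustrated on a toy table for a single band. This is only a sketch with made-up values for three hypothetical trees, not the notebook's helper function:

```python
import numpy as np
import pandas as pd

MAX_NODATA_MONTHS = 3
month_cols = [f"B02_{m:02d}" for m in range(1, 13)]

# Toy values for three hypothetical trees (not real reflectances)
values = np.tile(np.arange(1.0, 13.0), (3, 1))
values[1, [3, 4]] = np.nan            # tree "b": 2 missing months -> interpolate
values[2, [0, 1, 5, 6, 7]] = np.nan   # tree "c": 5 missing months -> remove
df = pd.DataFrame(values, columns=month_cols, index=["a", "b", "c"])

# Count missing months per tree and apply the threshold
nodata_months = df.isna().sum(axis=1)
keep = nodata_months <= MAX_NODATA_MONTHS

# Linear interpolation along the time axis, only for the trees that are kept
filled = df[keep].interpolate(axis=1, method="linear", limit_direction="both")

print(nodata_months.to_dict())
print(filled.loc["b", ["B02_04", "B02_05"]].tolist())
```

Tree "c" exceeds the threshold and is dropped; tree "b" gets its April and May values from the straight line between March and June.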

## 1. Setup & Imports

In [1]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# Install necessary libraries
!pip install geopandas rasterio --quiet

import os
import numpy as np
import pandas as pd
import geopandas as gpd
import rasterio
from rasterio.transform import rowcol
from pathlib import Path
from tqdm.notebook import tqdm
import warnings
warnings.filterwarnings('ignore')

print("✅ Imports and pip installs successful")

✅ Imports and pip installs successful


## 2. Configuration

In [3]:
# Base paths
BASE_DIR = Path("/content/drive/MyDrive/Studium/Geoinformation/Module/Projektarbeit")
DATA_DIR = BASE_DIR / "data"

# Input paths
CADASTRE_DIR = DATA_DIR / "tree_cadastres" / "corrected" / "processed"
S2_DIR = DATA_DIR / "sentinel2_2021"
CHM_DIR = DATA_DIR / "CHM" / "processed" / "CHM_10m"
OUTPUT_DIR = DATA_DIR / "features"

# Create output directory
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Cities
CITIES = ["Berlin", "Hamburg", "Rostock"]

# Sentinel-2 bands (10 spectral + 5 indices)
S2_BANDS = [
    "B02", "B03", "B04", "B05", "B06", "B07", "B08", "B8A", "B11", "B12",  # Spectral
    "NDre", "NDVIre", "kNDVI", "VARI", "RTVIcore"  # Indices
]

# Months (Jan-Dec)
MONTHS = list(range(1, 13))

# CHM variants
CHM_VARIANTS = ["mean", "max", "std"]

# NoData threshold (months with NoData)
MAX_NODATA_MONTHS = 3

print("✅ Configuration loaded")
print(f"  Base directory: {BASE_DIR}")
print(f"  Cadastre directory: {CADASTRE_DIR}")
print(f"  Cities: {CITIES}")
print(f"  S2 Bands: {len(S2_BANDS)} ({S2_BANDS[:3]}...{S2_BANDS[-2:]})")
print(f"  Months: {len(MONTHS)}")
print(f"  CHM Variants: {CHM_VARIANTS}")
print(f"  Max NoData Months: {MAX_NODATA_MONTHS}")

✅ Configuration loaded
  Base directory: /content/drive/MyDrive/Studium/Geoinformation/Module/Projektarbeit
  Cadastre directory: /content/drive/MyDrive/Studium/Geoinformation/Module/Projektarbeit/data/tree_cadastres/corrected/processed
  Cities: ['Berlin', 'Hamburg', 'Rostock']
  S2 Bands: 15 (['B02', 'B03', 'B04']...['VARI', 'RTVIcore'])
  Months: 12
  CHM Variants: ['mean', 'max', 'std']
  Max NoData Months: 3
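
Because the extraction helper fills a column with NaN when a raster file is missing, it can be worth listing the expected files up front. The sketch below assumes the file naming patterns used in the helper function of section 4; `expected_rasters` is a hypothetical helper and the relative paths are placeholders for `S2_DIR` and `CHM_DIR`:

```python
from pathlib import Path


def expected_rasters(s2_dir, chm_dir, cities, months, chm_variants):
    """List every raster path the extraction will try to open (hypothetical helper)."""
    paths = []
    for city in cities:
        paths += [Path(chm_dir) / f"CHM_10m_{v}_{city}.tif" for v in chm_variants]
        paths += [Path(s2_dir) / f"S2_{city}_2021_{m:02d}_median.tif" for m in months]
    return paths


paths = expected_rasters(
    "data/sentinel2_2021", "data/CHM/processed/CHM_10m",
    ["Berlin", "Hamburg", "Rostock"], range(1, 13), ["mean", "max", "std"],
)
missing = [p for p in paths if not p.exists()]
print(f"{len(paths)} rasters expected, {len(missing)} missing")
```

A missing monthly file would otherwise only show up later as one extra NoData month for every tree of that city.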


## 3. Load Tree Cadastre

In [6]:
# ============================================================================
# 3. LOAD & PRE-FILTER TREE CADASTRE
# ============================================================================

# Load individual city cadastres and concatenate
all_trees = []
for city in CITIES:
    cadastre_path = CADASTRE_DIR / f"trees_corrected_{city}.gpkg"
    if cadastre_path.exists():
        print(f"  Loading {city} trees from {cadastre_path.name}")
        city_trees_gdf = gpd.read_file(cadastre_path)
        city_trees_gdf['city'] = city  # Set (or overwrite) the city column so it always matches CITIES
        all_trees.append(city_trees_gdf)
    else:
        print(f"  Warning: Cadastre file not found for {city}: {cadastre_path.name}")

if not all_trees:
    raise FileNotFoundError("No tree cadastre files were loaded. Please check paths.")

trees_gdf = pd.concat(all_trees, ignore_index=True)

print(f"✅ Loaded {len(trees_gdf):,} trees in total")
print(f"  Cities: {trees_gdf['city'].value_counts().to_dict()}")
print(f"  Genera: {trees_gdf['genus_latin'].nunique()}")
print(f"  Columns: {list(trees_gdf.columns)}")

# Ensure required columns exist
required_cols = ["tree_id", "city", "genus_latin", "species_latin", "height_m", "geometry"]
missing_cols = [col for col in required_cols if col not in trees_gdf.columns]
if missing_cols:
    raise ValueError(f"Missing required columns: {missing_cols}")

print("✅ All required columns present")

print("\n--- Pre-Filtering ---")

# Store original count
trees_original = len(trees_gdf)

# ========================================================================
# FILTER 1: height_m must be valid (Snap-to-Peak validation)
# ========================================================================

trees_no_height = trees_gdf["height_m"].isna().sum()
print(f"  Filter 1 (height_m NoData): {trees_no_height:,} trees ({trees_no_height/trees_original*100:.1f}%)")

trees_gdf = trees_gdf[trees_gdf["height_m"].notna()].copy()

# ========================================================================
# FILTER 2: genus_latin must be valid (no label = no training)
# ========================================================================

trees_no_genus = trees_gdf["genus_latin"].isna().sum()
print(f"  Filter 2 (genus_latin NoData): {trees_no_genus:,} trees ({trees_no_genus/trees_original*100:.1f}%)")

trees_gdf = trees_gdf[trees_gdf["genus_latin"].notna()].copy()

# ========================================================================
# FILTER 3: plant_year <= 2021 (CHM reference year)
# ========================================================================

CHM_REFERENCE_YEAR = 2021

# Check if plant_year column exists
if "plant_year" in trees_gdf.columns:
    # Trees with plant_year > 2021
    trees_too_young = (trees_gdf["plant_year"] > CHM_REFERENCE_YEAR).sum()
    print(f"  Filter 3 (plant_year > {CHM_REFERENCE_YEAR}): {trees_too_young:,} trees ({trees_too_young/trees_original*100:.2f}%)")

    # Filter: Keep only trees with plant_year <= 2021 OR plant_year = NaN
    trees_gdf = trees_gdf[
        (trees_gdf["plant_year"].isna()) | (trees_gdf["plant_year"] <= CHM_REFERENCE_YEAR)
    ].copy()

    # Optional: Report NaN plant_years (for transparency)
    trees_unknown_age = trees_gdf["plant_year"].isna().sum()
    print(f"    → Trees with unknown plant_year (kept): {trees_unknown_age:,} ({trees_unknown_age/len(trees_gdf)*100:.1f}%)")
else:
    print(f"  Filter 3 (plant_year): Column not found, skipping")

# ========================================================================
# CLEANUP: Remove unnecessary columns
# ========================================================================

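# Columns to drop (if they exist)
# NOTE: the loaded cadastre names the snap column "snap_distance_m" (see the printed
# column list), so "snap_distance" matches nothing and only "plant_year" is dropped.
cols_to_drop = ["plant_year", "snap_distance"]  # Add more if needed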

existing_cols_to_drop = [col for col in cols_to_drop if col in trees_gdf.columns]

if existing_cols_to_drop:
    print(f"\n  Dropping columns: {existing_cols_to_drop}")
    trees_gdf = trees_gdf.drop(columns=existing_cols_to_drop)

# ========================================================================
# SUMMARY
# ========================================================================

trees_final = len(trees_gdf)
trees_removed = trees_original - trees_final

print(f"\n✅ Pre-filtering complete:")
print(f"  Original: {trees_original:,} trees")
print(f"  Removed: {trees_removed:,} trees ({trees_removed/trees_original*100:.1f}%)")
print(f"  Remaining: {trees_final:,} trees")
print(f"  Cities: {trees_gdf['city'].value_counts().to_dict()}")
print(f"  Genera: {trees_gdf['genus_latin'].nunique()}")

  Loading Berlin trees from trees_corrected_Berlin.gpkg
  Loading Hamburg trees from trees_corrected_Hamburg.gpkg
  Loading Rostock trees from trees_corrected_Rostock.gpkg
✅ Loaded 315,977 trees in total
  Cities: {'Berlin': 219900, 'Hamburg': 78577, 'Rostock': 17500}
  Genera: 8
  Columns: ['tree_id', 'city', 'genus_latin', 'species_latin', 'plant_year', 'height_m', 'snap_distance_m', 'geometry']
✅ All required columns present

--- Pre-Filtering ---
  Filter 1 (height_m NoData): 0 trees (0.0%)
  Filter 2 (genus_latin NoData): 0 trees (0.0%)
  Filter 3 (plant_year > 2021): 0 trees (0.00%)
    → Trees with unknown plant_year (kept): 32,415 (10.3%)

  Dropping columns: ['plant_year']

✅ Pre-filtering complete:
  Original: 315,977 trees
  Removed: 0 trees (0.0%)
  Remaining: 315,977 trees
  Cities: {'Berlin': 219900, 'Hamburg': 78577, 'Rostock': 17500}
  Genera: 8
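
Filter 3 keeps trees with an unknown planting year on purpose. The toy example below, with made-up values, shows why the `isna()` term is needed: a comparison against NaN is always False, so those rows would otherwise be dropped silently.

```python
import numpy as np
import pandas as pd

CHM_REFERENCE_YEAR = 2021

# Toy cadastre with hypothetical values
trees = pd.DataFrame({
    "tree_id": ["t1", "t2", "t3", "t4"],
    "plant_year": [1995, 2023, np.nan, 2021],
})

# NaN <= 2021 evaluates to False, so unknown years must be kept explicitly
keep = trees["plant_year"].isna() | (trees["plant_year"] <= CHM_REFERENCE_YEAR)
filtered = trees[keep]

print(filtered["tree_id"].tolist())
```

Only `t2`, planted after the CHM reference year, is removed.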


## 4. Helper Functions

In [14]:
def extract_features_vectorized(city_gdf, city, s2_dir, chm_dir, s2_bands, months, chm_variants):
    """
    Sample the CHM and Sentinel-2 rasters at every tree location of one city.

    Assumes point geometries in the same CRS as the rasters and that the band
    order of each S2 file matches `s2_bands`. Missing files yield NaN columns.
    """
    coords = [(geom.x, geom.y) for geom in city_gdf.geometry]
    df_res = pd.DataFrame({
        "tree_id": city_gdf["tree_id"],
        "city": city_gdf["city"],
        "genus_latin": city_gdf["genus_latin"],
        "species_latin": city_gdf["species_latin"],
        "height_m": city_gdf["height_m"]
    })

    # --- CHM EXTRACTION ---
    # One single-band raster per variant (mean, max, std)
    for variant in chm_variants:
        chm_file = chm_dir / f"CHM_10m_{variant}_{city}.tif"
        if chm_file.exists():
            with rasterio.open(chm_file) as src:
                sampled = np.array([val[0] for val in src.sample(coords)])
                if src.nodata is not None:
                    sampled = np.where(sampled == src.nodata, np.nan, sampled)
                df_res[f"CHM_{variant}"] = sampled
        else:
            df_res[f"CHM_{variant}"] = np.nan

    # --- SENTINEL-2 EXTRACTION ---
    # One multi-band raster per month; band i of the file maps to s2_bands[i]
    print(f"  [S2] Sampling {len(months)} months...")
    for month in tqdm(months, desc="Months", leave=False):
        s2_file = s2_dir / f"S2_{city}_2021_{month:02d}_median.tif"
        if s2_file.exists():
            with rasterio.open(s2_file) as src:
                sampled_all_bands = np.array(list(src.sample(coords)))
                for b_idx, band_name in enumerate(s2_bands):
                    if b_idx < src.count:
                        band_data = sampled_all_bands[:, b_idx].astype(float)
                        if src.nodata is not None:
                            band_data[band_data == src.nodata] = np.nan
                        df_res[f"{band_name}_{month:02d}"] = band_data
        else:
            # Missing month: fill every band with NaN so it counts as NoData
            for band_name in s2_bands:
                df_res[f"{band_name}_{month:02d}"] = np.nan
    return df_res

def apply_vectorized_nodata_logic(df, s2_bands, months, max_nodata_months):
    """
    Count NoData months per tree (using B02 as the reference band) and
    interpolate the gaps along the time axis.

    Returns the DataFrame and a boolean mask of the valid trees.
    """
    ref_cols = [f"B02_{m:02d}" for m in months if f"B02_{m:02d}" in df.columns]

    # IMPORTANT: the main loop relies on this helper column being named 'nodata_months'
    df["nodata_months"] = df[ref_cols].isna().sum(axis=1)

    mask_valid = df["nodata_months"] <= max_nodata_months
    mask_to_interpolate = (df["nodata_months"] > 0) & mask_valid

    if mask_to_interpolate.any():
        # Interpolate each band on its own. Interpolating across one wide block of
        # band-major columns would mix a band's December with the next band's January.
        for band in s2_bands:
            band_cols = [f"{band}_{m:02d}" for m in months if f"{band}_{m:02d}" in df.columns]
            if not band_cols:
                continue
            df.loc[mask_to_interpolate, band_cols] = df.loc[mask_to_interpolate, band_cols].interpolate(
                axis=1, limit_direction='both', method='linear'
            )

    return df, mask_valid
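
Conceptually, sampling a raster at point coordinates means looking up the pixel that contains each point. The sketch below shows that lookup on a plain NumPy array, assuming a north-up grid with square pixels; `sample_grid` and the toy grid are illustrative only and do not reproduce rasterio's implementation. Points outside the grid become NaN in this sketch.

```python
import numpy as np


def sample_grid(grid, x_min, y_max, res, coords, nodata=None):
    """Pixel lookup on a north-up grid (illustrative, not rasterio's code)."""
    xs, ys = np.asarray(coords, dtype=float).T
    cols = np.floor((xs - x_min) / res).astype(int)   # columns grow eastwards
    rows = np.floor((y_max - ys) / res).astype(int)   # rows grow southwards

    out = np.full(len(xs), np.nan)
    inside = (rows >= 0) & (rows < grid.shape[0]) & (cols >= 0) & (cols < grid.shape[1])
    out[inside] = grid[rows[inside], cols[inside]]
    if nodata is not None:
        out[out == nodata] = np.nan   # mask the raster's NoData value
    return out


grid = np.array([[1.0, 2.0], [3.0, -9999.0]])   # 2 x 2 pixels, 10 m each
coords = [(5.0, 15.0), (15.0, 15.0), (15.0, 5.0), (25.0, 5.0)]
result = sample_grid(grid, x_min=0.0, y_max=20.0, res=10.0, coords=coords, nodata=-9999.0)
print(result)
```

The third point falls on a NoData pixel and the fourth lies outside the grid, so both end up as NaN, which is the state the later NoData handling operates on.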

## 5. Feature Extraction (Per City)

In [15]:
print("="*60)
print("FEATURE EXTRACTION - FULL OPTIMIZED VERSION")
print("="*60)

all_results = []

for city in CITIES:
    print(f"\n--- Processing {city} ---")

    # 1. Preparation
    city_trees = trees_gdf[trees_gdf["city"] == city].copy()
    n_original = len(city_trees)
    print(f"  Trees in Cadastre: {n_original:,}")

    # 2. Batch extraction (read raster values)
    city_features_df = extract_features_vectorized(
        city_trees, city, S2_DIR, CHM_DIR, S2_BANDS, MONTHS, CHM_VARIANTS
    )

    # 3. CHM NoData Check & Filter
    print(f"\n  CHM NoData Check:")
    chm_cols = ["CHM_mean", "CHM_max", "CHM_std"]
    mask_chm_nodata = city_features_df[chm_cols].isna().any(axis=1)
    n_removed_chm = mask_chm_nodata.sum()

    city_features_df = city_features_df[~mask_chm_nodata].copy()
    print(f"    ✅ Removed {n_removed_chm:,} trees with CHM NoData")
    print(f"    Remaining: {len(city_features_df):,} trees")

    # 4. Sentinel-2 NoData handling (vectorized)
    print(f"\n  NoData Handling (Sentinel-2):")
    city_features_df, mask_valid = apply_vectorized_nodata_logic(
        city_features_df, S2_BANDS, MONTHS, MAX_NODATA_MONTHS
    )

    n_removed_s2 = (~mask_valid).sum()
    n_interpolated = ((city_features_df["nodata_months"] > 0) & mask_valid).sum()

    print(f"    Trees with >{MAX_NODATA_MONTHS} months NoData: {n_removed_s2:,} ({n_removed_s2/len(city_features_df)*100:.1f}%)")
    print(f"    Trees interpolated (1-{MAX_NODATA_MONTHS} months): {n_interpolated:,}")

    # Filter and drop the helper column
    city_features_df = city_features_df[mask_valid].drop(columns=["nodata_months"])
    print(f"    ✅ Filtered and Interpolated. Remaining: {len(city_features_df):,} trees")

    # 5. Validation & statistics
    s2_cols = [f"{b}_{m:02d}" for b in S2_BANDS for m in MONTHS]
    feature_cols = s2_cols + chm_cols + ["height_m"]

    # Check for remaining NaNs (should be 0)
    remaining_nan = city_features_df[feature_cols].isna().sum().sum()
    print(f"\n  Validation: Remaining NoData in features: {remaining_nan}")

    # 6. Merge geometry & save
    city_features_gdf = city_trees[["tree_id", "geometry"]].merge(
        city_features_df, on="tree_id", how="inner"
    )

    output_path = OUTPUT_DIR / f"trees_with_features_{city}.gpkg"
    city_features_gdf.to_file(output_path, driver="GPKG")
    print(f"\n  ✅ Saved {len(city_features_gdf):,} trees to {output_path.name}")

    # 7. Store the per-city summary
    n_final = len(city_features_gdf)
    all_results.append({
        "city": city,
        "trees_original": n_original,
        "trees_final": n_final,
        "trees_removed_total": n_original - n_final,
        "removal_percent": ((n_original - n_final) / n_original) * 100,
        "interpolated_trees": n_interpolated
    })

# Show the final summary
print("\n" + "="*60)
summary_df = pd.DataFrame(all_results)
print(summary_df)

FEATURE EXTRACTION - FULL OPTIMIZED VERSION

--- Processing Berlin ---
  Trees in Cadastre: 219,900
  [S2] Sampling 12 months...

  CHM NoData Check:
    ✅ Removed 2 trees with CHM NoData
    Remaining: 219,898 trees

  NoData Handling (Sentinel-2):
    Trees with >3 months NoData: 29,429 (13.4%)
    Trees interpolated (1-3 months): 145,353
    ✅ Filtered and Interpolated. Remaining: 190,469 trees

  Validation: Remaining NoData in features: 0

  ✅ Saved 190,469 trees to trees_with_features_Berlin.gpkg

--- Processing Hamburg ---
  Trees in Cadastre: 78,577
  [S2] Sampling 12 months...

  CHM NoData Check:
    ✅ Removed 24 trees with CHM NoData
    Remaining: 78,553 trees

  NoData Handling (Sentinel-2):
    Trees with >3 months NoData: 30,034 (38.2%)
    Trees interpolated (1-3 months): 46,350
    ✅ Filtered and Interpolated. Remaining: 48,519 trees

  Validation: Remaining NoData in features: 0

  ✅ Saved 48,519 trees to trees_with_features_Hamburg.gpkg

--- Processing Rostock ---
  Trees in Cadastre: 17,500
  [S2] Sampling 12 months...

  CHM NoData Check:
    ✅ Removed 0 trees with CHM NoData
    Remaining: 17,500 trees

  NoData Handling (Sentinel-2):
    Trees with >3 months NoData: 809 (4.6%)
    Trees interpolated (1-3 months): 15,039
    ✅ Filtered and Interpolated. Remaining: 16,691 trees

  Validation: Remaining NoData in features: 0

  ✅ Saved 16,691 trees to trees_with_features_Rostock.gpkg

      city  trees_original  trees_final  trees_removed_total  removal_percent  \
0   Berlin          219900       190469                29431        13.383811   
1  Hamburg           78577        48519                30058        38.252924   
2  Rostock           17500        16691                  809         4.622857   

   interpolated_trees  
0              145353  
1               46350  
2               15039  
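
The totals in the table can be reconciled with the per-step counts in the log. Note that the S2 percentages in the log use the tree count after the CHM filter as the denominator, while `removal_percent` uses the original count, which explains small differences such as 38.2% versus 38.25% for Hamburg. A quick check with the numbers copied from the log:

```python
# Counts copied from the log above: (original, removed by CHM, removed by S2, final)
log = {
    "Berlin":  (219_900,  2, 29_429, 190_469),
    "Hamburg": (78_577,  24, 30_034,  48_519),
    "Rostock": (17_500,   0,    809,  16_691),
}

for city, (original, chm_removed, s2_removed, final) in log.items():
    removed_total = chm_removed + s2_removed
    # Both filters together must account for every removed tree
    assert original - removed_total == final
    print(f"{city}: {removed_total:,} removed ({removed_total / original * 100:.2f}% of original)")
```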


## 6. Summary

In [16]:
print("\n" + "="*60)
print("SUMMARY")
print("="*60)

summary_df = pd.DataFrame(all_results)
print(summary_df.to_string(index=False))

print(f"\n✅ Feature extraction complete for {len(CITIES)} cities!")
print(f"   Output directory: {OUTPUT_DIR}")
print(f"   Feature structure: 184 features + metadata")
print(f"     - Tree attributes: tree_id, city, genus_latin, species_latin, geometry")
print(f"     - CHM features (4): height_m, CHM_mean, CHM_max, CHM_std")
print(f"     - S2 features (180): B02_01...RTVIcore_12")


SUMMARY
   city  trees_original  trees_final  trees_removed_total  removal_percent  interpolated_trees
 Berlin          219900       190469                29431        13.383811              145353
Hamburg           78577        48519                30058        38.252924               46350
Rostock           17500        16691                  809         4.622857               15039

✅ Feature extraction complete for 3 cities!
   Output directory: /content/drive/MyDrive/Studium/Geoinformation/Module/Projektarbeit/data/features
   Feature structure: 184 features + metadata
     - Tree attributes: tree_id, city, genus_latin, species_latin, geometry
     - CHM features (4): height_m, CHM_mean, CHM_max, CHM_std
     - S2 features (180): B02_01...RTVIcore_12
