In [1]:
import json
import pickle
from dataclasses import dataclass

import geopandas as gpd
import numpy as np
import polars as pl
from sklearn.preprocessing import LabelEncoder

In [3]:
# Container for everything that is exported to the pickle file at the end of
# this notebook.
@dataclass
class TrainTest:
    """Train/test arrays together with the metadata needed to interpret them.

    Attributes:
        feature_columns: Names of the columns that make up ``X`` / ``X_test``,
            in column order.
        label_encoder: The ``LabelEncoder`` used to produce ``y`` / ``y_test``;
            keep it to map the integer codes back to class names.
        X: Feature matrix of the training rows.
        y: Encoded labels of the training rows.
        groups: Group identifier of each training row.
        X_test: Feature matrix of the held-out rows.
        y_test: Encoded labels of the held-out rows.
        groups_test: Group identifier of each held-out row.
    """

    feature_columns: list[str]
    label_encoder: LabelEncoder
    X: np.ndarray
    y: np.ndarray
    groups: np.ndarray
    X_test: np.ndarray
    y_test: np.ndarray
    groups_test: np.ndarray

In [4]:
# Load the lookup table from numeric class code to class name.
# The encoding is pinned to UTF-8 so the file is decoded the same way on every
# platform instead of falling back to the locale's default encoding.
with open("data/classes.json", "r", encoding="utf-8") as f:
    # JSON object keys are always strings; cast them to int so they can be
    # compared with the integer class codes used further down.
    class_mapping = {int(k): v for k, v in json.load(f).items()}
class_mapping

{100: 'Alive Vegetation',
 110: 'Mature Forest',
 120: 'Revegetation',
 121: 'With Trees (after clear cut)',
 122: 'Canopy closing (after thinning/defoliation)',
 123: 'Without Trees (shrubs and grasses, no reforestation visible)',
 200: 'Disturbed',
 210: 'Planned',
 211: 'Clear Cut',
 212: 'Thinning',
 213: 'Forestry Mulching (Non Forest Vegetation Removal)',
 220: 'Salvage',
 221: 'After Biotic Disturbance',
 222: 'After Abiotic Disturbance',
 230: 'Biotic',
 231: 'Bark Beetle (with decline)',
 232: 'Gypsy Moth (temporary)',
 240: 'Abiotic',
 241: 'Drought',
 242: 'Wildfire',
 244: 'Wind',
 245: 'Avalanche',
 246: 'Flood'}

In [5]:
# Class codes kept for the export (an earlier, narrower selection was
# [110, 211, 221, 222]).
target_classes = [110, 211, 221, 222, 231, 242, 244]

# Merge both salvage sub-classes (221, 222) into the "Clear Cut" label.
# Note: this overwrites the names loaded from classes.json in place.
class_mapping.update({221: "Clear Cut", 222: "Clear Cut"})

In [6]:
# --- derived columns --------------------------------------------------------
# Numeric class code -> class name (fails loudly on a code missing from the map).
label_as_name = pl.col("labels").replace_strict(class_mapping, return_dtype=pl.String)

# Time elapsed since the earliest observation of the same (sample, label) pair.
# Both derived columns are computed from the *input* frame, so the window below
# still sees the numeric codes: 221 and 222 form separate windows even though
# both end up named "Clear Cut".
elapsed_in_label = (
    pl.col("timestamps") - pl.col("timestamps").min().over("sample_id", "labels")
).alias("duration_since_last_flag")

# --- row filters (all three must hold) ---------------------------------------
# Observations from May through October only.
in_season = pl.col("timestamps").dt.month().is_between(5, 10, closed="both")
# Keep SCL codes 4-7 (presumably Sentinel-2 scene-classification values - confirm).
accepted_scl = pl.col("SCL").is_in([4, 5, 6, 7])
# Non-forest classes are limited to the first 60 days of their label window;
# "Mature Forest" rows are kept regardless of elapsed time.
within_window_or_forest = (
    pl.col("duration_since_last_flag") < pl.duration(days=60)
) | (pl.col("labels") == "Mature Forest")

# Load the pixel observations, restrict them to the target classes, add the
# derived columns and apply the row filters.
signal_data = (
    pl.read_parquet("data/pixel_data.parquet")
    .filter(pl.col("labels").is_in(target_classes))
    .with_columns(label_as_name, elapsed_in_label)
    .filter(in_season, accepted_scl, within_window_or_forest)
)
signal_data

B02,B03,B04,B05,B06,B07,B08,B8A,B11,B12,SCL,labels,sample_id,timestamps,duration_since_last_flag
u16,u16,u16,u16,u16,u16,u16,u16,u16,u16,u16,str,u16,date,duration[ms]
245,516,561,1288,3198,3903,3964,4265,2567,1417,4,"""Mature Forest""",2,2019-05-09,120d
353,628,648,1290,3239,4129,4476,4570,2475,1414,4,"""Mature Forest""",2,2019-05-14,125d
275,546,664,1259,3265,3996,4312,4411,2556,1405,4,"""Mature Forest""",2,2019-05-19,130d
231,557,571,1346,3666,4559,4736,5049,2622,1397,4,"""Mature Forest""",2,2019-05-29,140d
424,670,795,1334,3314,4338,4464,4732,2648,1523,4,"""Mature Forest""",2,2019-06-13,155d
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
161,348,272,648,1929,2370,2602,2631,1433,625,4,"""Wind""",4302,2023-09-08,22d
214,318,286,703,1758,2180,2491,2508,1471,733,4,"""Wind""",4302,2023-09-16,30d
171,344,313,781,1889,2318,2706,2692,1742,815,4,"""Wind""",4302,2023-09-18,32d
266,496,544,1328,2166,2593,2672,3117,2985,1633,4,"""Clear Cut""",4302,2023-09-28,2d


In [7]:
# Rows per class after filtering. Sorted by count (descending) so the table
# has the same row order on every run; the unsorted result has no guaranteed order.
signal_data["labels"].value_counts(sort=True)

labels,count
str,u32
"""Bark Beetle (with decline)""",1057
"""Wildfire""",553
"""Mature Forest""",475159
"""Clear Cut""",3761
"""Wind""",959


In [35]:
# Sample-level metadata (geometry dropped so the table converts to polars),
# reduced to the samples that are trusted for training.
groups = (
    pl.from_pandas(gpd.read_parquet("data/samples.parquet").drop(columns="geometry"))
    .with_columns(
        # Integer id per cluster, assigned before filtering so ids do not
        # depend on which samples survive.
        # NOTE(review): rle_id numbers *consecutive runs* of equal values, so
        # this is one id per cluster only if rows of the same cluster_id are
        # stored contiguously - confirm for samples.parquet.
        cluster_id_int=pl.col("cluster_id").rle_id(),
    )
    .filter(
        pl.col.confidence == "high",
        # Exclude samples whose comment mentions "TCD" or "border".
        # str.contains yields null for a missing comment and filter() drops
        # rows with a null predicate, so without fill_null(False) every sample
        # that has no comment at all would be discarded as well.
        ~pl.col.comment.str.contains("TCD", literal=True).fill_null(False),
        ~pl.col.comment.str.contains("border", literal=True).fill_null(False),
    )
)
groups

sample_id,original_sample_id,interpreter,dataset,source,source_description,s2_tile,cluster_id,cluster_description,comment,confidence,cluster_id_int
u64,i64,str,str,str,str,str,str,str,str,str,u32
0,0,"""pum""","""Evoland""","""EFFIS""","""Evoland Project, EFFIS Source …","""30SUF""","""0.0""","""Damage polygons""","""leichte Durchforstung 2021""","""high""",0
1,1,"""pum""","""Evoland""","""EFFIS""","""Evoland Project, EFFIS Source …","""30SUF""","""1.0""","""Damage polygons""","""Durchforstung_2021, kein Chang…","""high""",1
2,2,"""pum""","""Evoland""","""EFFIS""","""Evoland Project, EFFIS Source …","""30SUF""","""2.0""","""Damage polygons""","""Durchforstung_2021, kein Chang…","""high""",2
3,3,"""pum""","""Evoland""","""EFFIS""","""Evoland Project, EFFIS Source …","""30SUF""","""3.0""","""Damage polygons""","""Durchforstung 2021, unsicher""","""high""",3
4,4,"""pum""","""Evoland""","""EFFIS""","""Evoland Project, EFFIS Source …","""30SUF""","""5.0""","""Damage polygons""","""Durchforstung 2021, starke Dur…","""high""",4
…,…,…,…,…,…,…,…,…,…,…,…
4282,14467,"""vij""","""Windthrow""","""FORWIND + Copernicus Emergency…","""https://mapping.emergency.cope…",,"""LV20220807""","""Id of the Event, given as ISO2…","""no wt""","""high""",675
4283,15508,"""vij""","""Windthrow""","""FORWIND + Copernicus Emergency…","""https://mapping.emergency.cope…",,"""LV20220807""","""Id of the Event, given as ISO2…","""unclear salvage""","""high""",675
4285,14139,"""vij""","""Windthrow""","""FORWIND + Copernicus Emergency…","""https://mapping.emergency.cope…",,"""LV20220807""","""Id of the Event, given as ISO2…","""no wt""","""high""",675
4287,14238,"""vij""","""Windthrow""","""FORWIND + Copernicus Emergency…","""https://mapping.emergency.cope…",,"""LV20220807""","""Id of the Event, given as ISO2…","""no wt""","""high""",675


In [36]:
# Attach the cluster id to every pixel row. Because the join is an inner join,
# pixels whose sample is absent from `groups` are dropped here.
signal_data_with_cluster = signal_data.join(
    groups.select("sample_id", "cluster_id_int"),
    on="sample_id",
    how="inner",
)
signal_data_with_cluster

B02,B03,B04,B05,B06,B07,B08,B8A,B11,B12,SCL,labels,sample_id,timestamps,duration_since_last_flag,cluster_id_int
u16,u16,u16,u16,u16,u16,u16,u16,u16,u16,u16,str,u16,date,duration[ms],u32
245,516,561,1288,3198,3903,3964,4265,2567,1417,4,"""Mature Forest""",2,2019-05-09,120d,2
353,628,648,1290,3239,4129,4476,4570,2475,1414,4,"""Mature Forest""",2,2019-05-14,125d,2
275,546,664,1259,3265,3996,4312,4411,2556,1405,4,"""Mature Forest""",2,2019-05-19,130d,2
231,557,571,1346,3666,4559,4736,5049,2622,1397,4,"""Mature Forest""",2,2019-05-29,140d,2
424,670,795,1334,3314,4338,4464,4732,2648,1523,4,"""Mature Forest""",2,2019-06-13,155d,2
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
302,417,317,600,1427,1708,1743,1948,1073,583,4,"""Wind""",4293,2023-09-26,40d,675
126,285,238,488,1409,1684,2155,1991,1111,518,4,"""Wind""",4293,2023-09-28,42d,675
1118,1150,993,1161,1468,1577,1662,1597,600,274,7,"""Wind""",4293,2023-10-01,45d,675
234,383,310,578,1520,1745,2448,2154,1058,584,4,"""Wind""",4293,2023-10-06,50d,675


In [37]:
# Class name -> integer code. Fitted on the labels of `signal_data` (the frame
# before the cluster join), so every class present there receives a code.
le = LabelEncoder().fit(signal_data["labels"])
le.classes_

array(['Bark Beetle (with decline)', 'Clear Cut', 'Mature Forest',
       'Wildfire', 'Wind'], dtype='<U26')

In [38]:
# Read the predefined sample-id lists for the two partitions.
with open("data/train_ids.json", "r") as train_file:
    train_ids = json.load(train_file)
with open("data/val_ids.json", "r") as val_file:
    val_ids = json.load(val_file)

# Partition the pixel rows by the sample they belong to.
# Note: the frame named `test` is built from the *validation* id list.
sample_in_train = pl.col("sample_id").is_in(train_ids)
sample_in_val = pl.col("sample_id").is_in(val_ids)

train = signal_data_with_cluster.filter(sample_in_train)
test = signal_data_with_cluster.filter(sample_in_val)

## Export structured train/test data for spectral pixel data

In [40]:
# Spectral band columns used as model features, in export order.
feature_columns = [
    "B02", "B03", "B04", "B05", "B06",
    "B07", "B08", "B8A", "B11", "B12",
]

# Bundle both partitions into one container. Arrays are requested as writable
# copies so that consumers can modify them in place.
# NOTE(review): `groups` / `groups_test` carry sample_id; the cluster_id_int
# column joined upstream is not used here - confirm which grouping level is
# intended.
data = TrainTest(
    feature_columns=feature_columns,
    label_encoder=le,
    X=train.select(feature_columns).to_numpy(writable=True),
    y=le.transform(train.get_column("labels")),
    groups=train.get_column("sample_id").to_numpy(writable=True),
    X_test=test.select(feature_columns).to_numpy(writable=True),
    y_test=le.transform(test.get_column("labels")),
    groups_test=test.get_column("sample_id").to_numpy(writable=True),
)

# Persist everything as one pickle file. Pickle stores the class by reference,
# so a reader must have a compatible `TrainTest` definition available to load it.
with open("data/spectral_train_test.pkl", "wb") as pkl_file:
    pickle.dump(data, pkl_file)

In [36]:
# Show the repr of the exported container as a quick sanity check.
# NOTE(review): the stored output of this cell lists `groups` with dtype uint32,
# while the export cell fills it from sample_id (u16 in the frames shown above);
# the output appears to predate the current export code - re-run to refresh.
data

TrainTest(feature_columns=['B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B11', 'B12'], label_encoder=LabelEncoder(), X=array([[ 245,  516,  561, ..., 4265, 2567, 1417],
       [ 353,  628,  648, ..., 4570, 2475, 1414],
       [ 275,  546,  664, ..., 4411, 2556, 1405],
       ...,
       [ 401,  472,  278, ..., 4251, 1461,  634],
       [ 266,  496,  544, ..., 3117, 2985, 1633],
       [ 453,  617,  696, ..., 2897, 2724, 1627]],
      shape=(19688, 10), dtype=uint16), y=array([1, 1, 1, ..., 1, 0, 0], shape=(19688,)), groups=array([  2,   2,   2, ..., 675, 675, 675], shape=(19688,), dtype=uint32), X_test=array([[1136, 1490, 1604, ..., 2288, 2277, 1760],
       [1011, 1358, 1458, ..., 2384, 2243, 1830],
       [1058, 1442, 1556, ..., 2514, 2307, 1745],
       ...,
       [ 221,  345,  548, ..., 1597, 1848, 1120],
       [ 536,  568,  651, ..., 1425, 1273, 1067],
       [ 510,  597,  625, ..., 1217,  960,  690]],
      shape=(2486, 10), dtype=uint16), y_test=array([1, 1, 1, ...,