In [1]:
import geopandas as gpd

In [2]:
samples = gpd.read_parquet("data/samples.parquet")

# Validation data sampling

For validation we will mostly split on the cluster ids, so that whole clusters are held out rather than individual samples.

We will aim for a rough 90/10 split.

For HRVPP we can't use the cluster ids, since each cluster is too large (an entire S2 tile). A validation set made up of one or two whole clusters wouldn't have the necessary variance.

For Windthrow we can try taking a single windthrow event as validation. This should ensure that the algorithm can't learn from the timing of the event or from the specific spatial structure of each windthrow event. However, a single event would not provide good variance over the other parameters of windthrow events (for example, we couldn't show performance for winter storms vs. summer storms).

In [3]:
samples["dataset"].unique()

<StringArray>
['Evoland', 'HRVPP', 'Windthrow']
Length: 3, dtype: string

In [4]:
samples.groupby("dataset")["cluster_id"].nunique()

dataset
Evoland      484
HRVPP         14
Windthrow     12
Name: cluster_id, dtype: int64

In [11]:
# evoland validation
evoland_cluster_samples = (
    samples.query("dataset=='Evoland'")["cluster_id"]
    .drop_duplicates()
    .sample(frac=0.1, random_state=42)
)
evoland_val_samples = samples.query(
    "dataset=='Evoland' & cluster_id in @evoland_cluster_samples"
)
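Sampling 10% of the clusters does not guarantee 10% of the samples, because clusters differ in size. The following is a minimal sketch of how one could check this on a toy table; the `toy` DataFrame, its cluster names and its cluster sizes are made up for illustration and are not taken from `samples`.

```python
import pandas as pd

# Toy stand-in for `samples`: 20 hypothetical clusters holding 1..10 samples each.
toy = pd.DataFrame(
    {
        "sample_id": range(110),
        "cluster_id": [f"c{i:02d}" for i in range(20) for _ in range(1 + i % 10)],
    }
)

# Same pattern as above: draw 10% of the unique cluster ids ...
val_clusters = toy["cluster_id"].drop_duplicates().sample(frac=0.1, random_state=42)

# ... and assign every sample of a drawn cluster to validation.
is_val = toy["cluster_id"].isin(val_clusters)
val, train = toy[is_val], toy[~is_val]

# No cluster may appear on both sides of the split.
shared_clusters = set(val["cluster_id"]) & set(train["cluster_id"])

# The sample-level share depends on which clusters were drawn.
val_fraction = len(val) / len(toy)
print(
    f"{val['cluster_id'].nunique()} validation clusters, "
    f"{len(val)} of {len(toy)} samples ({val_fraction:.1%})"
)
```

The same two checks (no shared clusters, achieved sample fraction) could be run on `evoland_val_samples` to see how far the Evoland split lands from the 10% target.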

In [13]:
# hrvpp validation
# we're just sampling randomly on sample-id
hrvpp_val_samples = samples.query("dataset=='HRVPP'").sample(frac=0.1, random_state=42)

## Windthrow

First, let's find out what makes sense to sample.

In [14]:
# 10% of samples:
len(samples.query("dataset=='Windthrow'")) * 0.1

29.400000000000002

In [15]:
samples.query("dataset=='Windthrow'")["cluster_id"].value_counts()

cluster_id
CH20170802    38
PL20170817    36
DE20180118    34
LV20220807    32
NO20211119    32
SI20200205    27
IT20181028    25
DE20171110    25
IT20181029    23
SE20181028     9
FR20200122     7
AU20181028     6
Name: count, dtype: Int64

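From this it seems we can just about afford to hold out two events: most single events are already close to the ~29 samples that a 10% split would need. It makes sense to sample a summer and a winter/autumn event.
For now, let's pick Austria (AU20181028, 6 samples) and Norway (NO20211119, 32 samples), which together give 38 samples, or about 13% of the Windthrow set. If the ids encode the event date, both of these are autumn events, so a summer event is not yet covered.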

In [16]:
val_events = ["NO20211119", "AU20181028"]
wt_val_samples = samples.query("cluster_id in @val_events")
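As a quick arithmetic check, the share of Windthrow samples held out by this choice can be worked out directly from the per-event counts. The sketch below copies the counts from the `value_counts()` output above into a plain dictionary; `event_counts` is a name introduced here for illustration only.

```python
# Sample counts per windthrow event, copied from the value_counts() output above.
event_counts = {
    "CH20170802": 38, "PL20170817": 36, "DE20180118": 34, "LV20220807": 32,
    "NO20211119": 32, "SI20200205": 27, "IT20181028": 25, "DE20171110": 25,
    "IT20181029": 23, "SE20181028": 9, "FR20200122": 7, "AU20181028": 6,
}
val_events = ["NO20211119", "AU20181028"]

n_total = sum(event_counts.values())              # all Windthrow samples
n_val = sum(event_counts[e] for e in val_events)  # samples in the held-out events
val_share = n_val / n_total

print(f"{n_val} of {n_total} windthrow samples held out ({val_share:.1%})")
```

Swapping other event ids into `val_events` makes it easy to compare candidate pairs against the 10% target.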

## Compile train and val id lists

In [17]:
val_ids = sorted(
    evoland_val_samples["sample_id"].to_list()
    + hrvpp_val_samples["sample_id"].to_list()
    + wt_val_samples["sample_id"].to_list()
)
train_ids = sorted(set(samples["sample_id"]) - set(val_ids))

In [21]:
len(samples["sample_id"].unique()) * 0.1

389.20000000000005

In [19]:
len(val_ids)

433
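With 433 validation ids out of 3892 unique sample ids, the split comes out slightly above the 10% target. A small helper along the following lines could report the achieved fraction and guard against ids ending up in both lists. `split_report` is a hypothetical helper, not part of this notebook, and the integer ids are placeholders that only reproduce the counts above, assuming all 433 validation ids are distinct.

```python
def split_report(train_ids, val_ids):
    """Return split sizes and validation fraction; fail if the lists overlap."""
    train, val = set(train_ids), set(val_ids)
    overlap = train & val
    if overlap:
        raise ValueError(f"{len(overlap)} ids appear in both train and val")
    total = len(train) + len(val)
    return {
        "n_train": len(train),
        "n_val": len(val),
        "val_fraction": len(val) / total,
    }


# Placeholder ids with the same counts as above: 3892 unique samples, 433 in val.
report = split_report(train_ids=range(0, 3459), val_ids=range(3459, 3892))
print(f"val fraction: {report['val_fraction']:.1%}")
```

In the notebook itself the call would simply be `split_report(train_ids, val_ids)`.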

In [22]:
import json

with open("data/val_ids.json", "w") as f:
    json.dump(val_ids, f)
with open("data/train_ids.json", "w") as f:
    json.dump(train_ids, f)