In [None]:
%load_ext autoreload
%autoreload 2

import glob
import json
import os
import shutil
import time
from datetime import datetime

# Restrict the visible GPUs before torch (or anything importing it) is loaded.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import cv2
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import rasterio
import tifffile
import torch
from matplotlib.colors import LinearSegmentedColormap
from PIL import Image
from rasterio.windows import Window
from shapely.geometry import Polygon
from shapely.geometry import shape
from shapely.geometry import box

from src.train import train
from src.hyperparams import HyperParams, open_from_yaml
from src.data import *
# from src.validate import return_metrics_all_folds, visualize_single_model, cross_validation, visualize_s2_concat_bands_path
from src.visualize import visualize_s2_concat_bands_path, visualize_s1_path, scale_img, visualize_ls_concat_bands_path, visualize_cb_mux_path, visualize_cb_wpm_path

pd.set_option('display.max_colwidth', 250)
red_to_green = LinearSegmentedColormap.from_list("RedGreen", ["red", "green"])

In [None]:
# NOTE(review): this whole cell is disabled exploratory code. It counted, per event,
# how many Sentinel-1 acquisitions fall before/after the mask date. Before re-enabling:
#   * `events_sorted_by_area` is only defined in a later cell, so this cell fails on a
#     fresh kernel;
#   * the S1 events are collected in a list named `s2_valid_data_events`, which the
#     S2 cells further down also use;
#   * `max_cloud_nodata` is assigned but never read.
# max_cloud_nodata = 75
# s2_valid_data_events = []
# for event_id in events_sorted_by_area[:]:
#     # rand_int = np.random.randint(len(events_sorted_by_area))
#     # rand_int = k
#     # event_id = events_sorted_by_area[rand_int]
#     # s2_paths = sorted([k for k in glob.glob(f"data/{event_id}/sentinel2/*L2A*.tif") if not "ndvi" in k and not "ndwi" in k])
#     s1_paths = sorted([k for k in glob.glob(f"data/{event_id}/sentinel1/*.tif") if not "_nrpb" in k and not "_rfdi" in k and not "_rvi" in k and not "_vv_vh_ratio" in k])
#     if len(s1_paths) == 0:
#         continue
#     # s2_16d_paths = sorted([k for k in glob.glob(f"data/{event_id}/sentinel2/*16D*.tif") if not "ndvi" in k and not "ndwi" in k])
#     # ls_paths = sorted([k for k in glob.glob(f"data/{event_id}/landsat/*_L2SP_*.tif") if not "_nbr" in k and not "_ndwi" in k and not "_ndvi" in k])
#     # cb_wpm_paths = sorted([k for k in glob.glob(f"data/{event_id}/cbers4a/*_WPM_*.tif") if not "_ndvi" in k])
#     # cb_mux_paths = sorted([k for k in glob.glob(f"data/{event_id}/cbers4a/*_MUX_*.tif") if not "_ndvi" in k])
#     # print(f"{event_id}: {len(s2_paths):2} S2 L2A, {len(s2_16d_paths):2} S2 16D,{len(s1_paths):2} S1, {len(ls_paths):2} LS L2SP, {len(cb_wpm_paths):2} CB WPM, {len(cb_mux_paths):2} CB MUX")
#     # paths = s2_paths + s1_paths + ls_paths + cb_wpm_paths + cb_mux_paths + s2_16d_paths
#     # if len(paths) == 0:
#     #     continue
#     # continue
#     mask_path = glob.glob(f"data/{event_id}/*_mask.tif")[0]
#     mask_date = mask_path.split("/")[-1].split('_')[1]
#     mask_date_str = f"{mask_date[:4]}-{mask_date[4:6]}-{mask_date[6:8]}"
#     mask_date = datetime.strptime(mask_date, "%Y%m%d")
#     # print(mask_date)
    
#     date = mask_path.split("/")[-1].split('_')[1]
#     # Sort by the date right before the .tif
#     paths = sorted(s1_paths, key=lambda p: datetime.strptime(p.split('_')[-1].split('.')[0], "%Y%m%d"))
#     # break
#     paths_after, paths_before = [], []
#     for path in paths:
#         date = datetime.strptime(path.split('_')[-1].split('.')[0], "%Y%m%d")
#         if date > mask_date:
#             paths_after.append(path)
#         else:
#             paths_before.append(path)
#     print(f"{event_id}: {len(paths):2} S1 paths, {len(paths_before):2} before, {len(paths_after):2} after ({mask_date})")
#     if len(paths_before) > 0 and len(paths_after) > 0:
#         s2_valid_data_events.append(event_id)
# len(s2_valid_data_events)

# S2 single sensor, single before/after image

## Create cross validation dataset

In [None]:
# Load the alert polygons and list the alert codes from the largest to the smallest area.
alerts_geojson = "data/mapbiomas_alerts.geojson"
df = gpd.read_file(alerts_geojson)
alerts_by_area = df.sort_values(by="areaHa", ascending=False)
events_sorted_by_area = alerts_by_area["alertCode"].tolist()
df.shape

In [None]:
# Build the raw Sentinel-2 catalogue as parallel lists (one entry per catalogue row):
#   * one row per event that has at least one image before AND after the mask date;
#   * extra synthetic "no change" rows (label_path == "empty", 0 deforested pixels) made
#     from two images taken on the same side of the mask date.
paths_before_, paths_after_ = [], []
clouds_nodata_before_, clouds_nodata_after_ = [], []
label_paths, deforest_pxs = [], []
s2_valid_data_events = []
seed = 1121
seed_count = 0

# Thresholds applied to channel index 2 of each image: values above CLOUD_THRESHOLD are
# counted as cloud, values at or below NODATA_THRESHOLD as no-data.
CLOUD_THRESHOLD = 3400
NODATA_THRESHOLD = 1000


def _acquisition_date(path):
    """Parse the ``YYYYMMDD`` token located right before the file extension."""
    return datetime.strptime(path.split('_')[-1].split('.')[0], "%Y%m%d")


def _clouds_nodata_percent(path):
    """Return cloud % + no-data % of the image at ``path`` (both measured on channel index 2)."""
    img_orig = tifffile.imread(path)
    n_pixels = img_orig.shape[0] * img_orig.shape[1]
    clouds_perc = 100 * (img_orig[:, :, 2] > CLOUD_THRESHOLD).sum() / n_pixels
    nodata_perc = 100 * (img_orig[:, :, 2] <= NODATA_THRESHOLD).sum() / n_pixels
    return clouds_perc + nodata_perc


def _pick_no_change_pair(n_images, rng_seed):
    """Draw two distinct image indices with a fixed seed; returned as (earlier, later)."""
    indices = np.arange(n_images)
    np.random.seed(rng_seed)
    np.random.shuffle(indices)
    earlier, later = sorted(indices[:2])
    return earlier, later


for event_id in events_sorted_by_area[:]:
    s2_paths = sorted([k for k in glob.glob(f"data/{event_id}/sentinel2/*L2A*.tif") if "ndvi" not in k and "ndwi" not in k and "compressed" not in k])
    if len(s2_paths) == 0:
        continue

    # The mask file name carries the reference date as its second "_"-separated token.
    mask_path = glob.glob(f"data/{event_id}/*_mask.tif")[0]
    mask_date = datetime.strptime(mask_path.split("/")[-1].split('_')[1], "%Y%m%d")

    # Chronological order, so "before" lists end at the image closest to the mask date and
    # "after" lists start at it.
    paths = sorted(s2_paths, key=_acquisition_date)
    paths_after, paths_before = [], []
    clouds_nodata_after, clouds_nodata_before = [], []
    for path in paths:
        clouds_nodata = _clouds_nodata_percent(path)
        if _acquisition_date(path) > mask_date:
            paths_after.append(path)
            clouds_nodata_after.append(clouds_nodata)
        else:
            paths_before.append(path)
            clouds_nodata_before.append(clouds_nodata)

    print(f"{event_id}: {len(paths):2} S2 L2A paths, {len(paths_before):2} before, {len(paths_after):2} after, ({mask_date})")
    if len(clouds_nodata_before) > 0 and len(clouds_nodata_after) > 0:
        print(f" --> min nodata/clouds before: {min(clouds_nodata_before):5.1f}%, after: {min(clouds_nodata_after):5.1f}%")

    if len(paths_before) > 0 and len(paths_after) > 0:
        # Real change example: label is the event mask, deforested pixels are those == 255.
        label_paths.append(mask_path)
        label = tifffile.imread(mask_path)
        deforest_pxs.append((label == 255).sum())

        s2_valid_data_events.append(event_id)
        paths_after_.append(paths_after)
        clouds_nodata_after_.append(clouds_nodata_after)
        paths_before_.append(paths_before)
        clouds_nodata_before_.append(clouds_nodata_before)

    if len(paths_before) > 2:  # Create a new no-deforestation example from the before images
        earlier, later = _pick_no_change_pair(len(paths_before), seed + seed_count)
        seed_count += 1

        s2_valid_data_events.append(f"{event_id}_no_change_before")
        label_paths.append("empty")
        deforest_pxs.append(0)
        paths_after_.append([paths_before[later]])
        clouds_nodata_after_.append([clouds_nodata_before[later]])
        paths_before_.append([paths_before[earlier]])
        clouds_nodata_before_.append([clouds_nodata_before[earlier]])

    if len(paths_after) > 2:  # Create a new no-deforestation example from the after images
        earlier, later = _pick_no_change_pair(len(paths_after), seed + seed_count)
        seed_count += 1

        s2_valid_data_events.append(f"{event_id}_no_change_after")
        label_paths.append("empty")
        deforest_pxs.append(0)
        paths_after_.append([paths_after[later]])
        clouds_nodata_after_.append([clouds_nodata_after[later]])
        paths_before_.append([paths_after[earlier]])
        clouds_nodata_before_.append([clouds_nodata_after[earlier]])
len(s2_valid_data_events)

In [None]:
# Assemble the parallel lists into the catalogue table (one row per training example).
catalogue_columns = {
    "event_id": s2_valid_data_events,
    "deforest_pxs": deforest_pxs,
    "label_path": label_paths,
    "paths_before": paths_before_,
    "clouds_nodata_before": clouds_nodata_before_,
    "paths_after": paths_after_,
    "clouds_nodata_after": clouds_nodata_after_,
}
df = pd.DataFrame(catalogue_columns)
df.shape

In [None]:
def best_before_path(row):
    """Pick the "before" image to use for one catalogue row.

    ``row["paths_before"]`` and ``row["clouds_nodata_before"]`` are parallel sequences in
    chronological order. Candidates are scanned from the last entry (closest to the
    deforestation date) backwards; the first one below 10 % clouds+nodata wins, then the
    first below 30 %, then the first below 100 %. If none qualifies, the last entry is used.

    Returns:
        tuple: ``(path, clouds_nodata)`` of the selected image.
    """
    clouds = row["clouds_nodata_before"]
    paths = row["paths_before"]
    newest_first = list(zip(clouds[::-1], paths[::-1]))
    for max_clouds_nodata in (10, 30, 100):
        for clouds_nodata, path in newest_first:
            if clouds_nodata < max_clouds_nodata:
                return path, clouds_nodata
    # Nothing below 100 %: fall back to the image closest to the deforestation date.
    return paths[-1], clouds[-1]

def best_after_path(row):
    """Pick the "after" image to use for one catalogue row.

    ``row["paths_after"]`` and ``row["clouds_nodata_after"]`` are parallel sequences in
    chronological order. Candidates are scanned from the first entry (closest to the
    deforestation date) onwards; the first one below 10 % clouds+nodata wins, then the
    first below 30 %, then the first below 100 %. If none qualifies, the first entry is used.

    Returns:
        tuple: ``(path, clouds_nodata)`` of the selected image.
    """
    clouds = row["clouds_nodata_after"]
    paths = row["paths_after"]
    oldest_first = list(zip(clouds, paths))
    for max_clouds_nodata in (10, 30, 100):
        for clouds_nodata, path in oldest_first:
            if clouds_nodata < max_clouds_nodata:
                return path, clouds_nodata
    # Nothing below 100 %: fall back to the image closest to the deforestation date.
    return paths[0], clouds[0]
    
# Select the before/after image of every row. Each selector returns (path, clouds_nodata);
# it is evaluated once per row and the two parts are split into separate columns.
best_before = [best_before_path(row) for _, row in df.iterrows()]
best_after = [best_after_path(row) for _, row in df.iterrows()]

df["path_best_before"] = [path for path, _ in best_before]
df["path_best_after"] = [path for path, _ in best_after]

df["clouds_nodata_best_before"] = [clouds_nodata for _, clouds_nodata in best_before]
df["clouds_nodata_best_after"] = [clouds_nodata for _, clouds_nodata in best_after]

In [None]:
# Distribution of the clouds+nodata percentage of the selected "before" image.
ax = df["clouds_nodata_best_before"].hist()
ax.set(title="Selected before image", xlabel="clouds + nodata (%)", ylabel="catalogue rows");

In [None]:
# Distribution of the clouds+nodata percentage of the selected "after" image.
ax = df["clouds_nodata_best_after"].hist()
ax.set(title="Selected after image", xlabel="clouds + nodata (%)", ylabel="catalogue rows");

In [None]:
# Keep every synthetic "no change" row (label_path == "empty"); keep a real event only if
# both selected images are below the clouds+nodata limit.
MAX_CLOUDS_NODATA_PERC = 50

print(df.shape)
is_no_change = df["label_path"] == "empty"
is_clear_enough = (df["clouds_nodata_best_before"] < MAX_CLOUDS_NODATA_PERC) & (df["clouds_nodata_best_after"] < MAX_CLOUDS_NODATA_PERC)
# .copy(): a later cell adds a "fold" column; writing into a filtered slice would raise
# SettingWithCopyWarning.
df = df[is_no_change | is_clear_enough].copy()
print(df.shape)

In [None]:
# Sanity check: list the catalogue columns.
df.columns

In [None]:
# Search for a shuffle seed whose 5-fold assignment gives every fold a similar total of
# deforested pixels (standard deviation of the per-fold sums below MAX_FOLD_STD).
N_FOLDS = 5
MAX_FOLD_STD = 1250
MAX_SEED_TRIES = 100000  # upper bound so an unreachable threshold cannot loop forever
first_seed = 222221

for seed in range(first_seed, first_seed + MAX_SEED_TRIES):
    np.random.seed(seed)
    # Repeat 0..4 often enough to cover every row, shuffle, then cut to the row count so
    # fold sizes differ by at most one.
    folds = list(range(N_FOLDS)) * (len(df) // N_FOLDS + 1)
    np.random.shuffle(folds)
    df["fold"] = folds[:len(df)]
    std_counts = df.groupby("fold")["deforest_pxs"].sum().std().item()
    if std_counts < MAX_FOLD_STD:
        break
else:
    raise RuntimeError(
        f"No seed in [{first_seed}, {first_seed + MAX_SEED_TRIES}) gives a per-fold "
        f"deforest_pxs std below {MAX_FOLD_STD}"
    )
print(seed)

In [None]:
def _ensure_uncompressed(path):
    """Make sure an uncompressed copy of the raster at ``path`` exists and return its path.

    The copy is written next to the source as ``*_uncompressed.tif`` and is only created
    when missing. A path that already points at such a copy is returned unchanged, so the
    cell can be re-run without stacking suffixes.
    """
    if path.endswith("_uncompressed.tif"):
        return path
    output_path = path.replace(".tif", "_uncompressed.tif")
    if not os.path.exists(output_path):
        with rasterio.open(path) as src:
            profile = src.profile
            profile["compress"] = None
            with rasterio.open(output_path, "w", **profile) as dst:
                dst.write(src.read())
    return output_path


# Point the catalogue at the uncompressed copies of the selected images.
df["path_best_before"] = df["path_best_before"].apply(_ensure_uncompressed)
df["path_best_after"] = df["path_best_after"].apply(_ensure_uncompressed)

In [None]:
# Persist the catalogue so the training code can load it through "df_path".
catalogue_dir = "catalogues"
catalogue_file = "catalogues/2025_08_19_s2_single.pkl"
os.makedirs(catalogue_dir, exist_ok=True)
df.to_pickle(catalogue_file)
print(df.shape)

In [None]:
# Total number of deforested pixels per fold — shows how balanced the split is.
df.groupby("fold")["deforest_pxs"].sum()

## Train

In [None]:
# Experiment S2_single_3-0 on the 2025_08_19 catalogue, folds 0-4.
# DRY_RUN=True only builds the datasets/loaders of each fold (the inspection cells below
# use the `train_dataset` left over from the last fold); set it to False to actually train.
DRY_RUN = True

exp_nb = "S2_single_3-0"
path = "catalogues/2025_08_19_s2_single.pkl"
hps_dict = {
    ############
    # Data
    ############
    "df_path": path,

    ############
    # Training
    ############

    ## Experiment Setup
    "name": f"Exp{exp_nb}",

    ## Model
    "num_classes": 1,
    "input_channel": 15,
    "backbone": "timm_efficientnet_b1",
    "pretrained": 1,
    "model": "unetplusplus",

    # Training Setup
#     "resume": "trained_models/Exp7-4/fold_0/2022-02-05_23-13-25/best_metric_18_0.9768.pt",
    "print_freq": 500,
    "use_fp16": 0,
    "patience": 8, #5,
    "epoch_start_scheduler": 10,

    # Optimizer
    "lr": 0.001,
    "weight_decay": 0.0,

    ## Data Augmentation on CPU
    "train_crop_size": 256 + 32 + 32,
    "train_batch_size": 16, #32
    "cutmix_alpha": 0,
    "da_brightness_magnitude": 0.0,
    "da_contrast_magnitude": 0.0,

    # Data Augmentation on GPU
    "gpu_da_params": [0.25],

    ### Loss, Metric
    "alpha": 0.25,
#     "loss": "lovasz",
}

# torch.cuda.device_count() is 0 when no GPU is visible; clamp it so the division below
# cannot raise ZeroDivisionError.
n_devices = max(1, torch.cuda.device_count())

for fold_nb in range(0, 5):
    hps = HyperParams(**hps_dict)
    hps.fold_nb = fold_nb
    num_batches = 250 // (hps.train_batch_size * n_devices)
    # num_batches = 10
    hps.num_batches = num_batches
    train_dataset, train_loader, val_dataset, val_loader = get_dataloaders(hps)
    if DRY_RUN:
        continue

    best_metric, best_metric_epoch = train(hps, train_loader, val_loader)

In [None]:
# Inspect the catalogue row(s) of a single alert (event_id 1387993).
df[df["event_id"]==1387993]

In [None]:
# Visual check of the training data. `train_dataset` is whatever the last executed
# training cell left behind (its final fold).
# NOTE(review): the meaning of the argument 10 (sample index vs. count) is defined in
# src.data — confirm there.
train_dataset.visualize(10)

In [None]:
# Print the distinct mask values of the first 100 training samples as a label sanity check.
for sample_idx in range(100):
    sample = train_dataset[sample_idx]
    mask_values = np.unique(sample["mask"])
    print(mask_values)

In [None]:
# Experiment S2_single_1-1: 5-fold training with 400 px crops and alpha = 0.5.
# NOTE(review): this reads catalogues/2025_08_18_s2_single.pkl, which no cell in this
# notebook writes (the catalogue cell above saves 2025_08_19_s2_single.pkl) — confirm the
# file exists and is the intended one.
exp_nb = "S2_single_1-1"
path = "catalogues/2025_08_18_s2_single.pkl"
hps_dict = {
    ############
    # Data
    ############
    "df_path": path,

    ############
    # Training
    ############

    ## Experiment Setup
    "name": f"Exp{exp_nb}",

    ## Model
    "num_classes": 1,
    "input_channel": 15,
    "backbone": "timm_efficientnet_b1",
    "pretrained": 1,
    "model": "unetplusplus",

    # Training Setup
#     "resume": "trained_models/Exp7-4/fold_0/2022-02-05_23-13-25/best_metric_18_0.9768.pt",
    "print_freq": 500,
    "use_fp16": 0,
    "patience": 6, #5,
    "epoch_start_scheduler": 6,

    # Optimizer
    "lr": 0.001,
    "weight_decay": 0.0,

    ## Data Augmentation on CPU
    "train_crop_size": 400,
    "train_batch_size": 16, #32
    "cutmix_alpha": 0,
    "da_brightness_magnitude": 0.0,
    "da_contrast_magnitude": 0.0,

    # Data Augmentation on GPU
    "gpu_da_params": [0.25],

    ### Loss, Metric
    "alpha": 0.5,
#     "loss": "lovasz",
}

# torch.cuda.device_count() is 0 when no GPU is visible; clamp it so the division below
# cannot raise ZeroDivisionError.
n_devices = max(1, torch.cuda.device_count())

for fold_nb in range(0, 5):
    hps = HyperParams(**hps_dict)
    hps.fold_nb = fold_nb
    num_batches = 250 // (hps.train_batch_size * n_devices)
    # num_batches = 10
    hps.num_batches = num_batches
    train_dataset, train_loader, val_dataset, val_loader = get_dataloaders(hps)
    # continue

    best_metric, best_metric_epoch = train(hps, train_loader, val_loader)

In [None]:
# WARNING: destructive. Deletes the stored outputs of experiment S2_single_1-2, which
# the next cell trains again. Skipped when the directory does not exist, so a fresh
# Restart & Run All does not fail here.
stale_run_dir = "trained_models/ExpS2_single_1-2/"
if os.path.isdir(stale_run_dir):
    shutil.rmtree(stale_run_dir)

In [None]:
# Experiment S2_single_1-2: 5-fold training with 320 px crops and alpha = 0.5.
# NOTE(review): this reads catalogues/2025_08_18_s2_single.pkl, which no cell in this
# notebook writes (the catalogue cell above saves 2025_08_19_s2_single.pkl) — confirm the
# file exists and is the intended one.
exp_nb = "S2_single_1-2"
path = "catalogues/2025_08_18_s2_single.pkl"
hps_dict = {
    ############
    # Data
    ############
    "df_path": path,

    ############
    # Training
    ############

    ## Experiment Setup
    "name": f"Exp{exp_nb}",

    ## Model
    "num_classes": 1,
    "input_channel": 15,
    "backbone": "timm_efficientnet_b1",
    "pretrained": 1,
    "model": "unetplusplus",

    # Training Setup
#     "resume": "trained_models/Exp7-4/fold_0/2022-02-05_23-13-25/best_metric_18_0.9768.pt",
    "print_freq": 500,
    "use_fp16": 0,
    "patience": 6, #5,
    "epoch_start_scheduler": 10,

    # Optimizer
    "lr": 0.001,
    "weight_decay": 0.0,

    ## Data Augmentation on CPU
    "train_crop_size": 256 + 32 + 32,
    "train_batch_size": 16, #32
    "cutmix_alpha": 0,
    "da_brightness_magnitude": 0.0,
    "da_contrast_magnitude": 0.0,

    # Data Augmentation on GPU
    "gpu_da_params": [0.25],

    ### Loss, Metric
    "alpha": 0.5,
#     "loss": "lovasz",
}

# torch.cuda.device_count() is 0 when no GPU is visible; clamp it so the division below
# cannot raise ZeroDivisionError.
n_devices = max(1, torch.cuda.device_count())

for fold_nb in range(0, 5):
    hps = HyperParams(**hps_dict)
    hps.fold_nb = fold_nb
    num_batches = 250 // (hps.train_batch_size * n_devices)
    # num_batches = 10
    hps.num_batches = num_batches
    train_dataset, train_loader, val_dataset, val_loader = get_dataloaders(hps)
    # continue

    best_metric, best_metric_epoch = train(hps, train_loader, val_loader)

## Create cross-validation dataset using the last after image

In [None]:
# Reload the alert polygons and list the alert codes from the largest to the smallest area.
alerts_geojson = "data/mapbiomas_alerts.geojson"
df = gpd.read_file(alerts_geojson)
alerts_by_area = df.sort_values(by="areaHa", ascending=False)
events_sorted_by_area = alerts_by_area["alertCode"].tolist()
df.shape

In [None]:
# Collect, per event, the Sentinel-2 images taken before and after the mask date together
# with their clouds+nodata percentage. Only events with images on both sides are kept.
paths_before_, paths_after_ = [], []
clouds_nodata_before_, clouds_nodata_after_ = [], []
s2_valid_data_events = []

# Thresholds applied to channel index 2 of each image: values above CLOUD_THRESHOLD are
# counted as cloud, values at or below NODATA_THRESHOLD as no-data.
CLOUD_THRESHOLD = 3400
NODATA_THRESHOLD = 1000


def _acquisition_date(path):
    """Parse the ``YYYYMMDD`` token located right before the file extension."""
    return datetime.strptime(path.split('_')[-1].split('.')[0], "%Y%m%d")


def _clouds_nodata_percent(path):
    """Return cloud % + no-data % of the image at ``path`` (both measured on channel index 2)."""
    img_orig = tifffile.imread(path)
    n_pixels = img_orig.shape[0] * img_orig.shape[1]
    clouds_perc = 100 * (img_orig[:, :, 2] > CLOUD_THRESHOLD).sum() / n_pixels
    nodata_perc = 100 * (img_orig[:, :, 2] <= NODATA_THRESHOLD).sum() / n_pixels
    return clouds_perc + nodata_perc


for event_id in events_sorted_by_area[:]:
    s2_paths = sorted([k for k in glob.glob(f"data/{event_id}/sentinel2/*L2A*.tif") if "ndvi" not in k and "ndwi" not in k and "uncompressed" not in k])
    if len(s2_paths) == 0:
        continue

    # The mask file name carries the reference date as its second "_"-separated token.
    mask_path = glob.glob(f"data/{event_id}/*_mask.tif")[0]
    mask_date = datetime.strptime(mask_path.split("/")[-1].split('_')[1], "%Y%m%d")

    # Chronological order: index 0 is the oldest image, index -1 the newest.
    paths = sorted(s2_paths, key=_acquisition_date)
    paths_after, paths_before = [], []
    clouds_nodata_after, clouds_nodata_before = [], []
    for path in paths:
        clouds_nodata = _clouds_nodata_percent(path)
        if _acquisition_date(path) > mask_date:
            paths_after.append(path)
            clouds_nodata_after.append(clouds_nodata)
        else:
            paths_before.append(path)
            clouds_nodata_before.append(clouds_nodata)

    print(f"{event_id}: {len(paths):2} S2 L2A paths, {len(paths_before):2} before, {len(paths_after):2} after, ({mask_date})")
    if len(clouds_nodata_before) > 0 and len(clouds_nodata_after) > 0:
        print(f" --> min nodata/clouds before: {min(clouds_nodata_before):5.1f}%, after: {min(clouds_nodata_after):5.1f}%")
    if len(paths_before) > 0 and len(paths_after) > 0:
        s2_valid_data_events.append(event_id)
        paths_after_.append(paths_after)
        clouds_nodata_after_.append(clouds_nodata_after)
        paths_before_.append(paths_before)
        clouds_nodata_before_.append(clouds_nodata_before)
len(s2_valid_data_events)

In [None]:
# Assemble the parallel lists into the catalogue table (one row per event).
catalogue_columns = {
    "event_id": s2_valid_data_events,
    "paths_before": paths_before_,
    "clouds_nodata_before": clouds_nodata_before_,
    "paths_after": paths_after_,
    "clouds_nodata_after": clouds_nodata_after_,
}
df = pd.DataFrame(catalogue_columns)
df.shape

In [None]:
def best_before_path(row):
    """Pick the "before" image for one catalogue row, preferring the EARLIEST date.

    NOTE: this re-defines the function of the same name from the earlier section, which
    scans newest-first instead.

    ``row["paths_before"]`` and ``row["clouds_nodata_before"]`` are parallel sequences in
    chronological order. Candidates are scanned oldest-first; the first one below 10 %
    clouds+nodata wins, then the first below 30 %, then the first below 100 %. If none
    qualifies, the oldest entry is used.

    Returns:
        tuple: ``(path, clouds_nodata)`` of the selected image.
    """
    clouds = row["clouds_nodata_before"]
    paths = row["paths_before"]
    oldest_first = list(zip(clouds, paths))
    for max_clouds_nodata in (10, 30, 100):
        for clouds_nodata, path in oldest_first:
            if clouds_nodata < max_clouds_nodata:
                return path, clouds_nodata
    # Nothing below 100 %: fall back to the oldest image.
    return paths[0], clouds[0]

def best_after_path(row):
    """Pick the "after" image for one catalogue row, preferring the LATEST date.

    NOTE: this re-defines the function of the same name from the earlier section, which
    scans oldest-first instead.

    ``row["paths_after"]`` and ``row["clouds_nodata_after"]`` are parallel sequences in
    chronological order. Candidates are scanned newest-first; the first one below 10 %
    clouds+nodata wins, then the first below 30 %, then the first below 100 %. If none
    qualifies, the newest entry is used.

    Returns:
        tuple: ``(path, clouds_nodata)`` of the selected image.
    """
    clouds = row["clouds_nodata_after"]
    paths = row["paths_after"]
    newest_first = list(zip(clouds[::-1], paths[::-1]))
    for max_clouds_nodata in (10, 30, 100):
        for clouds_nodata, path in newest_first:
            if clouds_nodata < max_clouds_nodata:
                return path, clouds_nodata
    # Nothing below 100 %: fall back to the newest image.
    return paths[-1], clouds[-1]
    
# Store the path of the selected before/after image of every row.
before_candidates = df[["paths_before", "clouds_nodata_before"]]
after_candidates = df[["paths_after", "clouds_nodata_after"]]
df["path_best_before"] = before_candidates.apply(lambda candidates: best_before_path(candidates)[0], axis=1)
df["path_best_after"] = after_candidates.apply(lambda candidates: best_after_path(candidates)[0], axis=1)

In [None]:
# Store the clouds+nodata percentage of the selected before/after image of every row.
before_candidates = df[["paths_before", "clouds_nodata_before"]]
after_candidates = df[["paths_after", "clouds_nodata_after"]]
df["clouds_nodata_best_before"] = before_candidates.apply(lambda candidates: best_before_path(candidates)[1], axis=1)
df["clouds_nodata_best_after"] = after_candidates.apply(lambda candidates: best_after_path(candidates)[1], axis=1)

In [None]:
# Attach to every event its mask file and the number of mask pixels equal to 255
# (counted as deforested).
label_paths = [glob.glob(f"data/{event_id}/*_mask.tif")[0] for event_id in df["event_id"].tolist()]
deforest_pxs = [(tifffile.imread(mask_file) == 255).sum() for mask_file in label_paths]
df["deforest_pxs"] = deforest_pxs
df["label_path"] = label_paths
df

In [None]:
# Sanity check: list the catalogue columns.
df.columns

In [None]:
# Search for a shuffle seed whose 5-fold assignment gives every fold a similar total of
# deforested pixels (standard deviation of the per-fold sums below MAX_FOLD_STD).
N_FOLDS = 5
MAX_FOLD_STD = 2000
MAX_SEED_TRIES = 100000  # upper bound so an unreachable threshold cannot loop forever
first_seed = 222221

for seed in range(first_seed, first_seed + MAX_SEED_TRIES):
    np.random.seed(seed)
    # Repeat 0..4 often enough to cover every row, shuffle, then cut to the row count so
    # fold sizes differ by at most one.
    folds = list(range(N_FOLDS)) * (len(df) // N_FOLDS + 1)
    np.random.shuffle(folds)
    df["fold"] = folds[:len(df)]
    std_counts = df.groupby("fold")["deforest_pxs"].sum().std().item()
    if std_counts < MAX_FOLD_STD:
        break
else:
    raise RuntimeError(
        f"No seed in [{first_seed}, {first_seed + MAX_SEED_TRIES}) gives a per-fold "
        f"deforest_pxs std below {MAX_FOLD_STD}"
    )

In [None]:
def _ensure_uncompressed(path):
    """Make sure an uncompressed copy of the raster at ``path`` exists and return its path.

    The copy is written next to the source as ``*_uncompressed.tif`` and is only created
    when missing (the source is not even opened otherwise). A path that already points at
    such a copy is returned unchanged, so the cell can be re-run without stacking suffixes.
    """
    if path.endswith("_uncompressed.tif"):
        return path
    output_path = path.replace(".tif", "_uncompressed.tif")
    if not os.path.exists(output_path):
        with rasterio.open(path) as src:
            profile = src.profile
            profile["compress"] = None
            with rasterio.open(output_path, "w", **profile) as dst:
                dst.write(src.read())
    return output_path


# Point the catalogue at the uncompressed copies of the selected images.
df["path_best_before"] = df["path_best_before"].apply(_ensure_uncompressed)
df["path_best_after"] = df["path_best_after"].apply(_ensure_uncompressed)

In [None]:
# Persist the catalogue so the training code can load it through "df_path".
catalogue_dir = "catalogues"
catalogue_file = "catalogues/2025_08_18_s2_single_last_after_img.pkl"
os.makedirs(catalogue_dir, exist_ok=True)
df.to_pickle(catalogue_file)
print(df.shape)

In [None]:
# Total number of deforested pixels per fold — shows how balanced the split is.
df.groupby("fold")["deforest_pxs"].sum()

## Train

In [None]:
# Experiment S2_single_2-0: 5-fold training on the "last after image" catalogue saved above.
exp_nb = "S2_single_2-0"
path = "catalogues/2025_08_18_s2_single_last_after_img.pkl"
hps_dict = {
    ############
    # Data
    ############
    "df_path": path,

    ############
    # Training
    ############

    ## Experiment Setup
    "name": f"Exp{exp_nb}",

    ## Model
    "num_classes": 1,
    "input_channel": 15,
    "backbone": "timm_efficientnet_b1",
    "pretrained": 1,
    "model": "unetplusplus",

    # Training Setup
#     "resume": "trained_models/Exp7-4/fold_0/2022-02-05_23-13-25/best_metric_18_0.9768.pt",
    "print_freq": 500,
    "use_fp16": 0,
    "patience": 8, #5,
    "epoch_start_scheduler": 10,

    # Optimizer
    "lr": 0.001,
    "weight_decay": 0.0,

    ## Data Augmentation on CPU
    "train_crop_size": 256 + 32 + 32,
    "train_batch_size": 16, #32
    "cutmix_alpha": 0,
    "da_brightness_magnitude": 0.0,
    "da_contrast_magnitude": 0.0,

    # Data Augmentation on GPU
    "gpu_da_params": [0.25],

    ### Loss, Metric
    "alpha": 0.25,
#     "loss": "lovasz",
}

# torch.cuda.device_count() is 0 when no GPU is visible; clamp it so the division below
# cannot raise ZeroDivisionError.
n_devices = max(1, torch.cuda.device_count())

for fold_nb in range(0, 5):
    hps = HyperParams(**hps_dict)
    hps.fold_nb = fold_nb
    num_batches = 250 // (hps.train_batch_size * n_devices)
    # num_batches = 10
    hps.num_batches = num_batches
    train_dataset, train_loader, val_dataset, val_loader = get_dataloaders(hps)
    # continue

    best_metric, best_metric_epoch = train(hps, train_loader, val_loader)