# JamUNet model trained with the spatial dataset - training and validation

This notebook was used for training and validating the model.

In [1]:
# move to the repository root directory

%cd ..

c:\Users\piete\OneDrive\Documenten\DSAIE MORPH\jamunet-morpho-braided


In [2]:
# reload modules to avoid restarting the notebook every time these are updated

%load_ext autoreload
%autoreload 2

In [None]:
# import modules 

import torch
import joblib
import copy

from torch.utils.data import DataLoader, TensorDataset
from torch.optim.lr_scheduler import StepLR

from model.train_eval import * 
from preprocessing.dataset_generation import create_full_dataset
from preprocessing.dataset_generation import combine_datasets
from postprocessing.save_results import *
from postprocessing.plot_results import *

# enable interactive widgets in Jupyter Notebook
%matplotlib inline
%matplotlib widget

In [4]:
# set the device where operations are performed
# if only one GPU is present you might need to remove the index "0" 
# torch.device('cuda:0') --> torch.device('cuda') / torch.cuda.get_device_name(0) --> torch.cuda.get_device_name() 

# Select the device on which tensors and the model are placed.
# `device` is always a torch.device so that later cells can rely on one type
# regardless of whether a GPU is present.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    print("CUDA Device Count: ", torch.cuda.device_count())
    print("CUDA Device Name: ", torch.cuda.get_device_name(0))
else:
    # same type as the CUDA branch (the original used the plain string 'cpu')
    device = torch.device('cpu')

print(f'Using device: {device}')

CUDA Device Count:  1
CUDA Device Name:  NVIDIA GeForce RTX 3070
Using device: cuda:0


In [5]:
import os
num_cpus = os.cpu_count()  # total logical cores
print("Logical CPU cores:", num_cpus)

Logical CPU cores: 16


In [6]:
# set common keys required for functions

train = 'training'
val = 'validation'
test = 'testing'

In [7]:
import os
print("Current working directory:", os.getcwd())

Current working directory: c:\Users\piete\OneDrive\Documenten\DSAIE MORPH\jamunet-morpho-braided


In [8]:
# show the directory all relative data/model paths below are resolved against
import os
print("Working directory:", os.getcwd())

Working directory: c:\Users\piete\OneDrive\Documenten\DSAIE MORPH\jamunet-morpho-braided


In [9]:
# import os
# import numpy as np
# from natsort import natsorted
# from osgeo import gdal

# def merge_yearly_to_npy(base_dir, collection='JRC_GSW1_4_MonthlyHistory'):
#     """
#     Merge all monthly .tif/.npy files of each year into a single .npy file per year.
#     Keeps original folder structure.

#     base_dir: root folder of preprocessed_PIETER
#     collection: substring to match collection folders (e.g. 'JRC_GSW1_4_MonthlyHistory')
#     """
#     # loop over subfolders in base_dir
#     for folder_name in os.listdir(base_dir):
#         if collection in folder_name:
#             folder_path = os.path.join(base_dir, folder_name)
#             if not os.path.isdir(folder_path):
#                 continue
            
#             # collect all tif/npy files (you currently only have .tif)
#             all_files = [
#                 f for f in os.listdir(folder_path)
#                 if f.endswith('.tif') or f.endswith('.npy')
#             ]
#             all_files = natsorted(all_files)  # sort by date in filenames

#             if not all_files:
#                 print(f"No tif/npy files found in {folder_path}")
#                 continue

#             # group by year (first 4 characters of filename, e.g. '1987_12_01.tif')
#             files_by_year = {}
#             for f in all_files:
#                 year = f[:4]
#                 if not year.isdigit():
#                     print(f"Skipping file with unexpected name: {f}")
#                     continue
#                 if year not in files_by_year:
#                     files_by_year[year] = []
#                 files_by_year[year].append(os.path.join(folder_path, f))
            
#             # merge each year
#             for year, files in files_by_year.items():
#                 year_arrays = []
#                 for fpath in files:
#                     if fpath.endswith('.npy'):
#                         arr = np.load(fpath)
#                     elif fpath.endswith('.tif'):
#                         ds = gdal.Open(fpath)
#                         if ds is None:
#                             print(f"Could not open {fpath} with GDAL, skipping.")
#                             continue
#                         arr = ds.ReadAsArray().astype(np.float32)
#                         ds = None
#                     else:
#                         continue

#                     year_arrays.append(arr)

#                 # optional: sanity check that all arrays have the same shape
#                 shapes = {a.shape for a in year_arrays}
#                 if len(shapes) > 1:
#                     print(f"WARNING: Different shapes for year {year} in {folder_path}: {shapes}")
#                     # you can decide to continue or skip; here I skip:
#                     continue

#                 if len(year_arrays) != 12:
#                     print(f"Skipping {year} in {folder_path}, only {len(year_arrays)} files (expected 12)")
#                     continue

#                 # stack months along a new axis (0 = months)
#                 merged = np.stack(year_arrays, axis=0)  # shape: (12, H, W)

#                 # save merged file
#                 out_name = os.path.join(folder_path, f'{year}_allmonths.npy')
#                 np.save(out_name, merged)
#                 print(f'Merged {len(files)} files into {out_name}')


# # Example usage:
# merge_yearly_to_npy(r'C:\Users\piete\OneDrive\Documenten\DSAIE MORPH\jamunet-morpho-braided\data\satellite\preprocessed_PIETER')

In [10]:
# import os
# import glob
# import numpy as np
# from osgeo import gdal

# def tiff_to_npy(root_dir, scaled_classes=True):
#     """
#     Walk through all .tif files under root_dir and create .npy versions
#     in the same folders. Keeps folder structure intact.
#     """
#     tif_files = glob.glob(os.path.join(root_dir, "**", "*.tif"), recursive=True)
#     print(f"Found {len(tif_files)} .tif files")

#     for tif in tif_files:
#         npy_path = tif.replace(".tif", ".npy")
#         if os.path.exists(npy_path):
#             continue  # skip if .npy already exists

#         try:
#             ds = gdal.Open(tif)
#             arr = ds.ReadAsArray().astype(np.float32)

#             if scaled_classes:
#                 arr = arr.astype(np.int32)
#                 arr[arr == 0] = -1
#                 arr[arr == 1] = 0
#                 arr[arr == 2] = 1

#             np.save(npy_path, arr)
#             print(f"Saved {npy_path}")
#         except Exception as e:
#             print(f"Failed to process {tif}: {e}")

# # Usage
# dataset_path = r"data\satellite\preprocessed_PIETER"
# tiff_to_npy(dataset_path)


In [None]:
# import os
# import numpy as np
# import torch
# from torch.utils.data import Dataset


# class LazyDataset(Dataset):
#     """
#     Lazily load YEARLY .npy satellite images.

#     Assumes:
#       - combine_datasets(...) returns:
#           inputs: list of lists of yearly .npy paths
#           targets: list of single yearly .npy paths
#       - each yearly .npy has all 12 months (e.g. shape (12, H, W) or (12, C, H, W))

#     For each sample:
#       x = stack of (year_target - 1) yearly arrays  -> shape (T, ...)
#       y = 3rd month (index 2) of the target year    -> shape (...)
#     """

#     def __init__(
#         self,
#         train_val_test,
#         dir_folders=r"data\satellite\preprocessed_PIETER",
#         year_target=5,
#         scaled_classes=True,
#         dtype=torch.float32,
#     ):
#         self.train_val_test = train_val_test
#         self.dir_folders = dir_folders
#         self.year_target = year_target
#         self.scaled_classes = scaled_classes
#         self.dtype = dtype
#         self.samples = []

#         # gather paths only; do NOT load arrays here
#         for folder in os.listdir(dir_folders):
#             if train_val_test in folder:
#                 # folder e.g. "JRC_GSW1_4_MonthlyHistory_training_r1"
#                 reach_id = folder.split("_r", 1)[1]

#                 inputs, targets = combine_datasets(
#                     train_val_test,
#                     int(reach_id),
#                     year_target,
#                     dir_folders=dir_folders,
#                     scaled_classes=scaled_classes,
#                 )

#                 for inp_paths, tgt_path in zip(inputs, targets):
#                     # inp_paths: list of yearly .npy paths
#                     # tgt_path:  single yearly .npy path
#                     self.samples.append((inp_paths, tgt_path))

#     def __len__(self):
#         return len(self.samples)

#     def __getitem__(self, idx):
#         inp_paths, tgt_path = self.samples[idx]

#         # --- load input sequence (year_target - 1 yearly .npy files) ---
#         xs = [np.load(p, mmap_mode="r") for p in inp_paths]  # each: (12, H, W)
#         x_np = np.stack(xs, axis=0)  # (T, 12, H, W)

#         # flatten (T, 12) -> (T*12) along channels
#         if x_np.ndim != 4:
#             raise ValueError(f"Expected input shape (T, 12, H, W), got {x_np.shape}")
#         T, M, H, W = x_np.shape
#         x_np = x_np.reshape(T * M, H, W)  # (C, H, W) where C = T*12

#         # --- load target year and take ONLY month 3 ---
#         year_np = np.load(tgt_path, mmap_mode="r")  # (12, H, W)
#         y_np = year_np[2]  # 3rd month

#         x = torch.from_numpy(x_np).to(self.dtype)  # (C, H, W)
#         y = torch.from_numpy(y_np).to(self.dtype)  # (H, W) or (C', H, W) depending on saved format

#         return x, y


import os
import glob
import re
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image


class LazyDataset(Dataset):
    """
    Dataset that stores only file paths and reads the arrays in ``__getitem__``.

    Inputs:
        - the yearly .npy files that ``combine_datasets`` returns for a sample,
          stacked to (T, 12, H, W) and flattened to (T * 12, H, W)
    Target:
        - the single TIFF of the target year found in
            <target_root>/<reach_folder>/<year>_*.tif or *.tiff
          Exactly one TIFF per year is expected in that folder (with the default
          ``target_root`` this is the ``dataset_month3`` folder).

    Each item is the tuple ``(x, y_bin, valid_mask)``:
        - ``x``: input tensor cast to ``dtype``
        - ``y_bin``: int64 tensor, 1 where the TIFF value is 2 and 0 elsewhere
        - ``valid_mask``: float tensor, 0 where the TIFF value is 0 and 1 elsewhere
          (presumably 0 = no-data, 1 = non-water, 2 = water - TODO confirm
          against the preprocessing of the TIFFs)

    Parameters:
        train_val_test: substring selecting the reach folders (e.g. 'training')
        dir_folders: root folder containing one sub-folder per reach
        target_root: root folder containing the target TIFFs, one sub-folder
            per reach with the same name as in ``dir_folders``
        year_target: forwarded to ``combine_datasets``
        scaled_classes: forwarded to ``combine_datasets``
        dtype: torch dtype of the returned input tensor
    """

    def __init__(
        self,
        train_val_test,
        # os.path.join yields the same strings as before on Windows and valid
        # paths on other platforms
        dir_folders=os.path.join("data", "satellite", "preprocessed_PIETER"),
        target_root=os.path.join("data", "satellite", "dataset_month3"),
        year_target=5,
        scaled_classes=True,
        dtype=torch.float32,
    ):
        self.train_val_test = train_val_test
        self.dir_folders = dir_folders
        self.target_root = target_root
        self.year_target = year_target
        self.scaled_classes = scaled_classes
        self.dtype = dtype
        self.samples = []

        # sorted() makes the sample order independent of the file system's
        # listing order, so a sample index always refers to the same sample
        for folder in sorted(os.listdir(dir_folders)):
            if train_val_test not in folder:
                continue
            # ignore stray files (archives, notes, ...) that merely contain the key
            if not os.path.isdir(os.path.join(dir_folders, folder)):
                continue

            # folder e.g. "JRC_GSW1_4_MonthlyHistory_training_r1" -> reach id 1
            reach_match = re.search(r"_r(\d+)$", folder)
            if reach_match is None:
                raise ValueError(
                    f"Could not extract reach id from folder name: {folder}"
                )
            reach_id = int(reach_match.group(1))

            inputs, targets = combine_datasets(
                train_val_test,
                reach_id,
                year_target,
                dir_folders=dir_folders,
                scaled_classes=scaled_classes,
            )

            # only paths are stored here; the arrays are loaded in __getitem__
            for inp_paths, tgt_npy_path in zip(inputs, targets):
                tiff_path = self._build_tiff_path_from_npy(tgt_npy_path, folder)
                self.samples.append((inp_paths, tiff_path))

    def _extract_year_from_path(self, path: str) -> str:
        """Return the first run of four digits in the file name of ``path``.

        Raises:
            ValueError: if the file name contains no four-digit run.
        """
        base = os.path.basename(path)
        m = re.search(r"\d{4}", base)
        if m is None:
            raise ValueError(f"Could not extract year from: {base}")
        return m.group(0)

    def _build_tiff_path_from_npy(self, tgt_npy_path: str, reach_folder: str) -> str:
        """Map a yearly target .npy path to the TIFF of the same year.

        The TIFF is searched in ``<target_root>/<reach_folder>`` with the
        patterns ``<year>_*.tif`` and ``<year>_*.tiff``.

        Raises:
            FileNotFoundError: if the reach folder or the TIFF does not exist.
            ValueError: if no year can be extracted or more than one TIFF matches.
        """
        year = self._extract_year_from_path(tgt_npy_path)
        tiff_dir = os.path.join(self.target_root, reach_folder)

        if not os.path.isdir(tiff_dir):
            raise FileNotFoundError(f"Target directory not found: {tiff_dir}")

        tiff_paths = sorted(
            glob.glob(os.path.join(tiff_dir, f"{year}_*.tif"))
            + glob.glob(os.path.join(tiff_dir, f"{year}_*.tiff"))
        )

        if len(tiff_paths) == 0:
            raise FileNotFoundError(
                f"No tiffs found for year {year} in {tiff_dir}"
            )
        if len(tiff_paths) > 1:
            raise ValueError(
                f"Expected 1 tiff for year {year} in {tiff_dir}, "
                f"found {len(tiff_paths)}: {tiff_paths}"
            )

        # exactly one TIFF per year is expected in the target folder
        return tiff_paths[0]

    def __len__(self):
        """Number of (input paths, target TIFF path) samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` from disk and return ``(x, y_bin, valid_mask)``."""
        inp_paths, tiff_path = self.samples[idx]

        # ---------- INPUT FROM .NPY ----------
        # memory-map each yearly file; np.stack then copies them into one array
        xs = [np.load(p, mmap_mode="r") for p in inp_paths]   # each: (12, H, W)
        x_np = np.stack(xs, axis=0)                           # (T, 12, H, W)

        if x_np.ndim != 4:
            raise ValueError(f"Expected (T,12,H,W), got {x_np.shape}")

        T, M, H, W = x_np.shape
        x_np = x_np.reshape(T * M, H, W)                      # (C, H, W)

        # ---------- TARGET FROM .TIFF (PIL) ----------
        if not os.path.exists(tiff_path):
            raise FileNotFoundError(f"Missing TIFF: {tiff_path}")

        with Image.open(tiff_path) as img:
            y_np = np.array(img)

        # if multi-band, reduce to one channel
        if y_np.ndim == 3:
            y_np = y_np[:, :, 0]

        x = torch.from_numpy(x_np).to(self.dtype)

        # cast to int64 in numpy first: torch.from_numpy does not accept every
        # dtype PIL can produce (e.g. uint16 on older PyTorch releases)
        label = torch.from_numpy(np.ascontiguousarray(y_np, dtype=np.int64))
        valid_mask = (label != 0).float()
        y_bin = (label == 2).long()

        return x, y_bin, valid_mask

In [12]:
# load all datasets

# inputs are the yearly .npy stacks below dataset_path; the targets are read from
# LazyDataset's default target_root (the dataset_month3 folder) - pass target_root=...
# to LazyDataset if the targets of another month should be used
dataset_path = os.path.join('data', 'satellite', 'preprocessed_PIETER')

dtype = torch.float32

# the split keys (train, val, test) are the common keys defined at the top of the notebook
train_set = LazyDataset(train, dir_folders=dataset_path, dtype=dtype)
val_set = LazyDataset(val, dir_folders=dataset_path, dtype=dtype)
test_set = LazyDataset(test, dir_folders=dataset_path, dtype=dtype)

# sanity check: list some of the .npy files below the same dataset_path that was
# used to build the datasets (instead of a second hard-coded copy of the path)
files = glob.glob(os.path.join(dataset_path, "**", "*.npy"), recursive=True)
print(len(files), "NPY files found")
for npy_file in files[:10]:
    print(npy_file)



# train_set = create_full_dataset(train, dir_folders=dataset_path, device=device, dtype=dtype)
# val_set = create_full_dataset(val, dir_folders=dataset_path, device=device, dtype=dtype)
# test_set = create_full_dataset(test, dir_folders=dataset_path, device=device, dtype=dtype)



# from concurrent.futures import ProcessPoolExecutor
# from functools import partial

# build = partial(create_full_dataset, dir_folders=dataset_path, device=device, dtype=dtype)

# with ProcessPoolExecutor(max_workers=16) as ex:
#     train_set, val_set, test_set = ex.map(build, [train, val, test])

# without ProcessPoolExecutor: 2.40 seconds

1020 NPY files found
data/satellite/preprocessed_PIETER\JRC_GSW1_4_MonthlyHistory_testing_r1\1988_allmonths.npy
data/satellite/preprocessed_PIETER\JRC_GSW1_4_MonthlyHistory_testing_r1\1989_allmonths.npy
data/satellite/preprocessed_PIETER\JRC_GSW1_4_MonthlyHistory_testing_r1\1990_allmonths.npy
data/satellite/preprocessed_PIETER\JRC_GSW1_4_MonthlyHistory_testing_r1\1991_allmonths.npy
data/satellite/preprocessed_PIETER\JRC_GSW1_4_MonthlyHistory_testing_r1\1992_allmonths.npy
data/satellite/preprocessed_PIETER\JRC_GSW1_4_MonthlyHistory_testing_r1\1993_allmonths.npy
data/satellite/preprocessed_PIETER\JRC_GSW1_4_MonthlyHistory_testing_r1\1994_allmonths.npy
data/satellite/preprocessed_PIETER\JRC_GSW1_4_MonthlyHistory_testing_r1\1995_allmonths.npy
data/satellite/preprocessed_PIETER\JRC_GSW1_4_MonthlyHistory_testing_r1\1996_allmonths.npy
data/satellite/preprocessed_PIETER\JRC_GSW1_4_MonthlyHistory_testing_r1\1997_allmonths.npy


In [13]:
print(f'Training dataset samples: {len(train_set)},\n\
Validation dataset samples: {len(val_set)},\n\
Testing dataset samples: {len(test_set)}')

Training dataset samples: 840,
Validation dataset samples: 30,
Testing dataset samples: 30


**<span style="color:red">Attention!</span>** 
\
Uncomment the next cells if larger training, validation, and testing datasets are needed. These cells load all months datasets (January, February, March, and April) and then merge them into one dataset. 
\
Keep in mind that due to memory constraints, it is likely that not all four datasets can be loaded. 
\
Make sure to load the training, validation, and testing datasets in different cells to reduce memory issues.

In [14]:
# dataset_jan = r'data\satellite\dataset_month1'
# dataset_feb = r'data\satellite\dataset_month2'
# dataset_mar = r'data\satellite\dataset_month3'
# dataset_apr = r'data\satellite\dataset_month4'

# dtype=torch.float32

In [15]:
# inputs_train_jan, targets_train_jan = create_full_dataset(train, dir_folders=dataset_jan, device=device, dtype=dtype).tensors
# inputs_train_feb, targets_train_feb = create_full_dataset(train, dir_folders=dataset_feb, device=device, dtype=dtype).tensors
# inputs_train_mar, targets_train_mar = create_full_dataset(train, dir_folders=dataset_mar, device=device, dtype=dtype).tensors
# inputs_train_apr, targets_train_apr = create_full_dataset(train, dir_folders=dataset_apr, device=device, dtype=dtype).tensors

# inputs_train = torch.cat((inputs_train_jan, inputs_train_feb, inputs_train_mar, inputs_train_apr))
# targets_train = torch.cat((targets_train_jan, targets_train_feb, targets_train_mar, targets_train_apr))
# train_set = TensorDataset(inputs_train, targets_train)

In [16]:
# inputs_val_jan, targets_val_jan = create_full_dataset(val, dir_folders=dataset_jan, device=device, dtype=dtype).tensors
# inputs_val_feb, targets_val_feb = create_full_dataset(val, dir_folders=dataset_feb, device=device, dtype=dtype).tensors
# inputs_val_mar, targets_val_mar = create_full_dataset(val, dir_folders=dataset_mar, device=device, dtype=dtype).tensors
# inputs_val_apr, targets_val_apr = create_full_dataset(val, dir_folders=dataset_apr, device=device, dtype=dtype).tensors

# inputs_val = torch.cat((inputs_val_jan, inputs_val_feb, inputs_val_mar, inputs_val_apr))
# targets_val = torch.cat((targets_val_jan, targets_val_feb, targets_val_mar, targets_val_apr))
# val_set = TensorDataset(inputs_val, targets_val)

In [17]:
# inputs_test_jan, targets_test_jan = create_full_dataset(test, dir_folders=dataset_jan, device=device, dtype=dtype).tensors
# inputs_test_feb, targets_test_feb = create_full_dataset(test, dir_folders=dataset_feb, device=device, dtype=dtype).tensors
# inputs_test_mar, targets_test_mar = create_full_dataset(test, dir_folders=dataset_mar, device=device, dtype=dtype).tensors
# inputs_test_apr, targets_test_apr = create_full_dataset(test, dir_folders=dataset_apr, device=device, dtype=dtype).tensors

# inputs_test = torch.cat((inputs_test_jan, inputs_test_feb, inputs_test_mar, inputs_test_apr))
# targets_test = torch.cat((targets_test_jan, targets_test_feb, targets_test_mar, targets_test_apr))
# test_set = TensorDataset(inputs_test, targets_test)

**<span style="color:red">Attention!</span>** 
\
It is not needed to scale and normalize the dataset as the pixel values are already $[0, 1]$.
\
If scaling and normalization are performed anyway, then **the model inputs have to be changed**, as the normalized datasets are used instead.

In [18]:
# normalize inputs and targets using the training dataset

# scaler_x, scaler_y = scaler(train_set)

# normalized_train_set = normalize_dataset(train_set, scaler_x, scaler_y)
# normalized_val_set = normalize_dataset(val_set, scaler_x, scaler_y)
# normalized_test_set = normalize_dataset(test_set, scaler_x, scaler_y)

In [19]:
# save scalers to be loaded in separate notebooks (i.e., for testing the model)
# should not change unless seed is changed or augmentation increased (randomsplit changes)

# joblib.dump(scaler_x, r'model\scalers\scaler_x.joblib')
# joblib.dump(scaler_y, r'model\scalers\scaler_y.joblib')

In [20]:
# load JamUNet architecture

from model.st_unet.st_unet import *

# Number of input channels of the first convolution. LazyDataset flattens each
# sample from (T, 12, H, W) to T * 12 channels, so 48 presumably corresponds to
# 4 input years x 12 months (year_target=5) - TODO confirm. It can also be read
# from the data (this loads one sample from disk):
# n_channels = train_set[0][0].shape[0]
n_channels = 48
# NOTE(review): the targets returned by LazyDataset are binary (0/1), which
# presumably is why two classes are used - confirm against UNet3D
n_classes = 2
init_hid_dim = 8
kernel_size = 3
pooling = 'max'

# these values are also passed to the save functions further below
model = UNet3D(n_channels=n_channels, n_classes=n_classes, init_hid_dim=init_hid_dim, 
               kernel_size=kernel_size, pooling=pooling, bilinear=False, drop_channels=False)

In [21]:
# print model architecture

model

UNet3D(
  (inc): DoubleConv(
    (double_conv): Sequential(
      (0): Conv2d(48, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (conv3d): Conv3d(8, 8, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=False)
  )
  (down1): Down(
    (pooling): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (pool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (conv3d): Conv3d(16, 16, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=False)
      )

In [22]:
# print total number of parameters and model size

num_parameters = sum(p.numel() for p in model.parameters())
print(f"Number of parameters: {num_parameters:.2e}")
# use the actual storage size of every parameter instead of assuming 4 bytes
# (float32), so the number stays correct if the model is cast to another dtype
model_size_MB = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024 ** 2)
print(f"Model size: {model_size_MB:.2f} MB")

Number of parameters: 5.26e+05
Model size: 2.01 MB


**<span style="color:red">Attention!</span>** 
\
Since it is not necessary to scale and normalize the dataset (see above), the inputs to the DataLoaders are not the normalized datasets.
\
If normalization is performed, the normalized datasets become the inputs to the model.

In [23]:
# to make it faster
import torch
torch.backends.cudnn.benchmark = True

In [None]:
# hyperparameters
learning_rate = 0.05
batch_size = 16
num_epochs = 10
water_threshold = 0.5
physics = False    # no physics-induced loss terms in the training loss if False
alpha_er = 1e-4    # needed only if physics=True
alpha_dep = 1e-4   # needed only if physics=True

# optimizer to train the model with backpropagation
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

# scheduler for decreasing the learning rate
# every tot epochs (step_size) with given factor (gamma)
step_size = 15     # set to None to remove the scheduler
gamma = 0.75       # set to None to remove the scheduler
# `scheduler` is always defined (None when disabled); both values are checked
# explicitly instead of relying on the result of `step_size and gamma`
scheduler = None
if step_size is not None and gamma is not None:
    scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)

# dataloaders to input data to the model in batches -- see note above if normalization is performed
# pinned host memory is only useful (and only available) when a CUDA device is used
pin_memory = torch.cuda.is_available()

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=pin_memory)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=pin_memory)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=pin_memory)

# train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True, prefetch_factor=0)
# val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True, prefetch_factor=0)
# test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True, prefetch_factor=0)

In [25]:
# per-epoch histories of the training and validation losses
train_losses = []
val_losses = []

# per-epoch histories of the validation metrics (one independent list each)
accuracies, precisions, recalls, f1_scores, csi_scores = ([] for _ in range(5))

# classification loss - possible options: 'BCE', 'BCE_Logits', and 'Focal'
loss_f = 'BCE'

# regression loss for the physics-induced terms
# possible options: 'Huber', 'RMSE', and 'MAE'
loss_er_dep = 'Huber'




# IGNORE_LABEL = 255  # value to ignore in the loss

# def remap_labels(mask, nodata_value=-1, nonwater_value=0):
#     """
#     mask: (N, H, W) or (H, W) with original labels
#     returns: same shape, values in {0, 1, IGNORE_LABEL}
#     """
#     mask = mask.clone()

#     # nodata -> ignore
#     mask[mask == nodata_value] = IGNORE_LABEL

#     # non-water -> 0
#     mask[mask == nonwater_value] = 0

#     # everything else -> 1 (water)
#     mask[(mask != 0) & (mask != IGNORE_LABEL)] = 1

#     return mask.long()

# # classification loss: CrossEntropy with ignore_index for nodata
# criterion = nn.CrossEntropyLoss(ignore_index=IGNORE_LABEL)





# Trackers are created before the loop so the loop itself never reads an
# undefined name (e.g. `count` when the first validation loss is NaN).
best_loss = float('inf')
best_recall = float('-inf')
best_epoch = 0           # last epoch in which either criterion improved
best_epoch_loss = 0      # epoch of the minimum validation loss
best_epoch_recall = 0    # epoch of the maximum validation recall
count = 0                # epochs elapsed since the last improvement
print_every = 1          # print the progress every `print_every` epochs
use_scheduler = step_size is not None and gamma is not None

for epoch in range(1, num_epochs+1):

    # model training
    train_loss = training_unet(model, train_loader, optimizer, water_threshold=water_threshold,
                               device=device, loss_f=loss_f, physics=physics, alpha_er=alpha_er,
                               alpha_dep=alpha_dep, loss_er_dep=loss_er_dep)

    # model validation
    (val_loss, val_accuracy, val_precision,
     val_recall, val_f1_score, val_csi_score) = validation_unet(model, val_loader, device=device,
                                                                loss_f=loss_f,
                                                                water_threshold=water_threshold)

    # keep a copy of the model with min val loss
    if val_loss <= best_loss:
        best_model = copy.deepcopy(model)
        best_loss = val_loss
        best_epoch_loss = epoch
        best_epoch = epoch
        count = 0
    # keep a copy of the model with max recall
    if val_recall >= best_recall:
        best_model_recall = copy.deepcopy(model)
        best_recall = val_recall
        best_epoch_recall = epoch
        best_epoch = epoch
        count = 0

    train_losses.append(train_loss)
    val_losses.append(val_loss)

    accuracies.append(val_accuracy)
    precisions.append(val_precision)
    recalls.append(val_recall)
    f1_scores.append(val_f1_score)
    csi_scores.append(val_csi_score)

    count += 1

    if epoch % print_every == 0:
        print(f"Epoch: {epoch} | " +
              f"Training loss: {train_loss:.2e}, Validation loss: {val_loss:.2e}, Best validation loss: {best_loss:.2e} " +
              f" | Metrics: Accuracy: {val_accuracy:.3f}, Precision: {val_precision:.3f}, Recall: {val_recall:.3f},\
 F1-score: {val_f1_score:.3f}, CSI-score: {val_csi_score:.3f}, Best recall: {best_recall:.3f}")
        if use_scheduler:
            # learning rate that was used during this epoch
            print(f'Current learning rate: {scheduler.get_last_lr()[0]}')

    # Update the learning rate only after the optimizer steps of this epoch.
    # Stepping at the start of the epoch (before any optimizer.step()) shifts
    # the whole schedule one epoch early and triggers a PyTorch warning.
    if use_scheduler:
        scheduler.step()

Epoch: 1 | Training loss: 1.58e-01, Validation loss: 1.23e-01, Best validation loss: 1.23e-01  | Metrics: Accuracy: 0.953, Precision: 0.659, Recall: 0.750, F1-score: 0.701, CSI-score: 0.542, Best recall: 0.750
Current learning rate: 0.05
Epoch: 2 | Training loss: 1.06e-01, Validation loss: 1.04e-01, Best validation loss: 1.04e-01  | Metrics: Accuracy: 0.958, Precision: 0.721, Recall: 0.722, F1-score: 0.717, CSI-score: 0.561, Best recall: 0.750
Current learning rate: 0.05
Epoch: 3 | Training loss: 9.85e-02, Validation loss: 1.08e-01, Best validation loss: 1.04e-01  | Metrics: Accuracy: 0.957, Precision: 0.692, Recall: 0.785, F1-score: 0.734, CSI-score: 0.581, Best recall: 0.785
Current learning rate: 0.037500000000000006


KeyboardInterrupt: 

In [None]:
metrics = [accuracies, precisions, recalls, f1_scores, csi_scores]

In [None]:
# store training and validation losses and metrics to be stored in a .csv file for later postprocessing
# always check the dataset month key

save_losses_metrics(train_losses, val_losses, metrics, 'spatial', model, 3, init_hid_dim, 
                    kernel_size, pooling, learning_rate, step_size, gamma, batch_size, num_epochs, 
                    water_threshold, physics, alpha_er, alpha_dep, dir_output=r'model\losses_metrics')

**<span style="color:red">Attention!</span>** 
\
Always remember to rename the <code>save_path</code> file before running the whole notebook to avoid overwriting it.

In [None]:
# save model with min validation loss
# always check the dataset month key

save_model_path(best_model, 'spatial', 'loss', 3, init_hid_dim, kernel_size, pooling, learning_rate, 
                step_size, gamma, batch_size, num_epochs, water_threshold) 

In [None]:
# save model with max recall
# always check the dataset month key

save_model_path(best_model_recall, 'spatial', 'recall', 3, init_hid_dim, kernel_size, pooling, learning_rate, 
                step_size, gamma, batch_size, num_epochs, water_threshold) 

In [None]:
# test the min loss model - average loss and metrics

model_loss = copy.deepcopy(best_model)
# pass the same water_threshold that was used for validation during training, so
# the test metrics are computed with the same threshold as the validation metrics
test_loss, test_accuracy, test_precision, test_recall, test_f1_score, test_csi_score = validation_unet(
    model_loss, test_loader, device=device, loss_f=loss_f, water_threshold=water_threshold)

print(f'Average metrics for test dataset using model with best validation loss:\n\n\
{loss_f} loss:          {test_loss:.3e}\n\
Accuracy:          {test_accuracy:.3f}\n\
Precision:         {test_precision:.3f}\n\
Recall:            {test_recall:.3f}\n\
F1 score:          {test_f1_score:.3f}\n\
CSI score:         {test_csi_score:.3f}')

In [None]:
# test the max recall model - average loss and metrics

model_recall = copy.deepcopy(best_model_recall)
# pass the same water_threshold that was used for validation during training, so
# the test metrics are computed with the same threshold as the validation metrics
test_loss, test_accuracy, test_precision, test_recall, test_f1_score, test_csi_score = validation_unet(
    model_recall, test_loader, device=device, loss_f=loss_f, water_threshold=water_threshold)

print(f'Average metrics for test dataset using model with best validation recall:\n\n\
{loss_f} loss:          {test_loss:.3e}\n\
Accuracy:          {test_accuracy:.3f}\n\
Precision:         {test_precision:.3f}\n\
Recall:            {test_recall:.3f}\n\
F1 score:          {test_f1_score:.3f}\n\
CSI score:         {test_csi_score:.3f}')

In [None]:
plot_losses_metrics(train_losses, val_losses, metrics, model_recall, loss_f=loss_f)

In [None]:
# show_evolution(18, test_set, model_loss)

# reuse the device selected at the top of the notebook instead of forcing
# 'cuda:0', which fails on machines without a CUDA GPU
model_loss = model_loss.to(device)

# sanity check: the binary target built by LazyDataset should only contain 0 and 1
img, y, mask = train_set[0]
print(torch.unique(y))


show_all_images(18, test_set, model_loss, device=device)

In [None]:
# show_evolution(18, test_set, model_recall)
show_evolution_nolegend_nn(18, test_set, model_recall, device=device)

In [None]:
# show_evolution(18, val_set, model_recall)
show_evolution_nolegend_nn(18, val_set, model_recall, device=device)

In [None]:
single_roc_curve(model_loss, test_set, sample=18, device=device);

In [None]:
single_roc_curve(model_recall, test_set, sample=18, device=device);

In [None]:
get_total_roc_curve(model_loss, test_set, device=device);

In [None]:
get_total_roc_curve(model_recall, test_set, device=device);

In [None]:
single_pr_curve(model_loss, test_set, sample=19, device=device)

In [None]:
# show_evolution(18, test_set, model_loss)
show_evolution_nolegend_nn(18, test_set, model_loss, device=device)