# JamUNet model trained with the spatial dataset - training and validation

This notebook was used for training and validating the model.

In [1]:
# NOTE(review): machine-specific absolute path - point this to the local clone of the repository before running
%cd c:\Users\mathi\Desktop\TU Delft\TU Delft year 5\Data_science\Morphology_project\jamunet-morpho-braided

c:\Users\mathi\Desktop\TU Delft\TU Delft year 5\Data_science\Morphology_project\jamunet-morpho-braided


In [None]:
# import modules
import os
import torch
import joblib
import copy
import warnings
from torch.utils.data import DataLoader, TensorDataset, ConcatDataset
from torch.optim.lr_scheduler import StepLR
from model.train_eval import *
from preprocessing.dataset_generation import create_full_dataset
from postprocessing.save_results import *
from postprocessing.plot_results import *

# enable interactive widgets in Jupyter Notebook
%matplotlib inline
%matplotlib widget

# reload modules to avoid restarting the notebook every time these are updated
%load_ext autoreload
%autoreload 2

print("Current working directory:", os.getcwd())

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Current working directory: c:\Users\mathi\Desktop\TU Delft\TU Delft year 5\Data_science\Morphology_project\jamunet-morpho-braided


In [42]:
# set the device where operations are performed
# if only one GPU is present you might need to remove the index "0"
# torch.device('cuda:0') --> torch.device('cuda') / torch.cuda.get_device_name(0) --> torch.cuda.get_device_name()

if torch.cuda.is_available():
    # use the first CUDA device and report what is visible to PyTorch
    device = torch.device('cuda:0')
    print("CUDA Device Count: ", torch.cuda.device_count())
    print("CUDA Device Name: ", torch.cuda.get_device_name(0))
else:
    # no CUDA device available: fall back to the CPU
    # (note: this is a plain string, whereas the CUDA branch stores a torch.device)
    device = 'cpu'

print(f'Using device: {device}')

num_cpus = os.cpu_count()  # total logical cores
print("\nLogical CPU cores:", num_cpus)

CUDA Device Count:  1
CUDA Device Name:  Quadro P2000
Using device: cuda:0

Logical CPU cores: 12


# Modified code

This code should eventually be integrated into `dataset_generation.py`, but doing so failed several times and the exact cause of the error could not be identified. Therefore, the functions have been placed in this notebook so that they can be looped over safely.

In [30]:
# Suppress warnings for cleaner output
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# ==========================================
# 1. HELPER FUNCTIONS
# ==========================================

def load_image_array(path, scaled_classes=True):
    '''Load a single image with GDAL and return it as a NumPy array.

    Inputs:
           path = str, path of the image to open
           scaled_classes = bool, if True (default) the pixel classes are shifted
                            from {0, 1, 2} to {-1, 0, 1}

    Output:
           None if GDAL cannot open the file; otherwise the pixel array, of dtype
           float32 when `scaled_classes` is False and of integer dtype when it is True
    '''
    dataset = gdal.Open(path)
    if dataset is None:
        return None

    pixels = dataset.ReadAsArray().astype(np.float32)
    if not scaled_classes:
        return pixels

    # the remapping is applied in place and in this exact order, so that a value
    # which has just been rewritten is never matched by a later step
    # (-1 and 0 are the default `nodata_value` and `nonwater_value` of the dataset builders in this file)
    pixels = pixels.astype(int)
    for old_class, new_class in ((0, -1), (1, 0), (2, 1)):
        pixels[pixels == old_class] = new_class

    return pixels

def load_avg(train_val_test, reach, year, dir_averages):
    '''Load the average image (.csv without header) of a given year, trying several folders.

    The file `average_{year}_{train_val_test}_r{reach}.csv` is searched, in this order, in
    `average_{train_val_test}_r{reach}`, `average_training_r{reach}` and `average_r{reach}`
    (all relative to `dir_averages`). The first existing file is used.

    Inputs:
           train_val_test = str, dataset label used in the folder and file names
           reach = int, reach number
           year = int, year of the average image
           dir_averages = str, directory containing the average folders

    Output:
           2D NumPy array with the content of the .csv file, or None if no file is found
    '''
    csv_name = f'average_{year}_{train_val_test}_r{reach}.csv'
    candidate_folders = (
        f'average_{train_val_test}_r{reach}',   # standard location
        f'average_training_r{reach}',           # generic fallbacks
        f'average_r{reach}',
    )

    for folder in candidate_folders:
        candidate = os.path.join(dir_averages, folder, csv_name)
        if os.path.exists(candidate):
            return pd.read_csv(candidate, header=None).to_numpy()

    return None

def create_list_images(train_val_test, reach, dir_folders, collection):
    '''Find the folder of the requested reach and return the sorted paths of its .tif images.

    Folders whose name ends with `_r{reach}` are candidates. A candidate whose name contains
    `train_val_test` is preferred; only if none exists is the first candidate (alphabetical
    order) used. This way a 'training' request can never silently pick a validation or
    testing folder that shares the same reach number.

    Inputs:
           train_val_test = str, dataset label expected in the folder name (e.g. 'training')
           reach = int, reach number
           dir_folders = str, directory containing the reach folders
           collection = unused, kept for backward compatibility with existing callers

    Output:
           list of str, paths of the .tif files (empty if the directory or the reach folder is missing)
    '''
    if not os.path.exists(dir_folders):
        print(f"   ❌ Error: Directory not found: {dir_folders}")
        return []

    suffix = f'_r{reach}'
    # os.listdir returns entries in arbitrary order: sort them to make the choice deterministic
    reach_folders = sorted(
        name for name in os.listdir(dir_folders)
        if name.endswith(suffix) and os.path.isdir(os.path.join(dir_folders, name))
    )
    if not reach_folders:
        return []

    # prefer the folder tagged with the requested label, otherwise fall back on the reach id alone
    tagged_folders = [name for name in reach_folders if train_val_test in name]
    target_folder_path = os.path.join(dir_folders, (tagged_folders or reach_folders)[0])

    return [os.path.join(target_folder_path, image)
            for image in sorted(os.listdir(target_folder_path))
            if image.endswith('.tif')]

# ==========================================
# 2. DATASET LOGIC
# ==========================================

def create_datasets(train_val_test, reach, year_target=5, nodata_value=-1, dir_folders=r'data\satellite\dataset', 
                    collection=r'JRC_GSW1_4_MonthlyHistory', scaled_classes=True):
    '''Build the n-to-1 input and target sequences of one reach.

    Steps: list and load the images of the reach, replace the no-data pixels with the
    average image of the same year, then slide a window of `year_target` consecutive images
    (the first `year_target - 1` are the inputs, the last one is the target).

    Inputs:
           train_val_test = str, dataset label (e.g. 'training', 'validation', 'testing')
           reach = int, reach number
           year_target = int, length of each sequence including the target image
           nodata_value = int, pixel value to be replaced with the average image
           dir_folders = str, directory containing the reach folders
           collection = str, passed through to `create_list_images`
           scaled_classes = bool, passed through to `load_image_array`

    Output:
           input_dataset = list of lists of arrays, each with `year_target - 1` images
           target_dataset = list of one-element lists holding the target image
           Both are empty if no image can be loaded or there are fewer than `year_target` images.
    '''
    # 1. Get Images
    list_dir_images = create_list_images(train_val_test, reach, dir_folders, collection)
    if not list_dir_images:
        return [], []

    # 2. Load Images (files that cannot be opened are skipped)
    images_array = []
    loaded_years = []
    for idx, path in enumerate(list_dir_images):
        img = load_image_array(path, scaled_classes=scaled_classes)
        if img is None:
            continue
        images_array.append(img)

        # the year is the first 4-digit token of the file name (extension excluded)
        stem = os.path.splitext(os.path.basename(path))[0]
        parts = stem.replace('-', '_').split('_')
        try:
            year = int(next(p for p in parts if p.isdigit() and len(p) == 4))
        except (StopIteration, ValueError):
            # no year in the file name: fall back on the position in the sorted list
            year = 1988 + idx
        loaded_years.append(year)

    if len(images_array) < year_target:
        return [], []

    # 3. Locate the averages: <two levels above dir_folders>/averages
    parent_dir = os.path.dirname(dir_folders)
    river_base_dir = os.path.dirname(parent_dir)
    dir_averages_dynamic = os.path.join(river_base_dir, 'averages')

    # 4. Replace No-Data with the average of the same year (zeros if the average is missing)
    good_images_array = []
    for image, year in zip(images_array, loaded_years):
        avg = load_avg(train_val_test, reach, year, dir_averages=dir_averages_dynamic)
        if avg is None:
            avg = np.zeros_like(image)
        good_images_array.append(np.where(image == nodata_value, avg, image))

    # 5. Create Sequences (n-to-1)
    input_dataset = []
    target_dataset = []
    for i in range(len(good_images_array) - year_target + 1):
        input_dataset.append(good_images_array[i : i + year_target - 1])
        target_dataset.append([good_images_array[i + year_target - 1]])

    return input_dataset, target_dataset

def combine_datasets(train_val_test, reach, year_target=5, nonwater_threshold=480000, nodata_value=-1, nonwater_value=0,
                     dir_folders=r'data\satellite\dataset', collection=r'JRC_GSW1_4_MonthlyHistory', scaled_classes=True):
    '''Create the sequences of one reach and keep only those without overly "non-water" images.

    A sequence is kept only if every input image and the target image have fewer than
    `nonwater_threshold` pixels equal to `nonwater_value`.

    Inputs:
           train_val_test, reach, year_target, nodata_value, dir_folders, collection,
           scaled_classes = passed unchanged to `create_datasets`
           nonwater_threshold = int, number of non-water pixels from which an image is rejected
           nonwater_value = int, pixel value counted as non-water

    Output:
           filtered_inputs = list of input sequences (lists of arrays)
           filtered_targets = list of target images (arrays, no longer wrapped in a list)
    '''
    raw_inputs, raw_targets = create_datasets(
        train_val_test, reach, year_target, nodata_value,
        dir_folders, collection, scaled_classes
    )

    def has_too_much_nonwater(image):
        # True when the image reaches the non-water pixel threshold
        return np.sum(image == nonwater_value) >= nonwater_threshold

    filtered_inputs, filtered_targets = [], []
    for sequence, wrapped_target in zip(raw_inputs, raw_targets):
        # reject the sample as soon as one input image is not good enough
        if any(has_too_much_nonwater(image) for image in sequence):
            continue

        target = wrapped_target[0]
        if has_too_much_nonwater(target):
            continue

        filtered_inputs.append(sequence)
        filtered_targets.append(target)

    return filtered_inputs, filtered_targets

def create_full_dataset(train_val_test, year_target=5, nonwater_threshold=480000, nodata_value=-1, nonwater_value=0, 
                        dir_folders=None, name_filter=None, collection=r'JRC_GSW1_4_MonthlyHistory', 
                        scaled_classes=True, device='cpu', dtype=torch.float32):
    '''Merge the filtered sequences of all the reaches found in `dir_folders` into one TensorDataset.

    Every sub-folder named `..._r<int>` is treated as a reach. If `name_filter` is given,
    only folders containing it are considered and it is also used as dataset label;
    otherwise `train_val_test` is the label.

    Inputs:
           train_val_test = str, dataset label used when `name_filter` is not set
           year_target, nonwater_threshold, nodata_value, nonwater_value, collection,
           scaled_classes = passed unchanged to `combine_datasets`
           dir_folders = str, directory containing the reach folders
           name_filter = str or None, substring that the folder names must contain
           device = device on which the tensors are created
           dtype = torch dtype of the tensors

    Output:
           TensorDataset(inputs, targets); an empty TensorDataset (length 0) if the
           directory is missing or no sample passes the filters
    '''
    # `dir_folders` defaults to None, which os.path.exists cannot handle
    if dir_folders is None or not os.path.exists(dir_folders):
        print(f"   ⚠️ Path not found: {dir_folders}")
        return TensorDataset(torch.empty(0), torch.empty(0))

    # sorted so that the sample order does not depend on the file system
    potential_folders = sorted(f for f in os.listdir(dir_folders)
                               if os.path.isdir(os.path.join(dir_folders, f)))

    use_label = name_filter if name_filter else train_val_test

    all_inputs = []
    all_targets = []
    seen_reaches = set()

    for folder_name in potential_folders:
        # FILTER: If name_filter is set (e.g. 'validation'), folder MUST contain it
        if name_filter and name_filter not in folder_name:
            continue

        try:
            reach_id = int(folder_name.split('_r')[-1])
        except ValueError:
            # not a reach folder
            continue

        # the samples depend only on (label, reach): loading the same reach twice
        # would just duplicate them
        if reach_id in seen_reaches:
            continue
        seen_reaches.add(reach_id)

        inputs, targets = combine_datasets(
            train_val_test=use_label, 
            reach=reach_id, 
            year_target=year_target, 
            nonwater_threshold=nonwater_threshold,
            nodata_value=nodata_value, 
            nonwater_value=nonwater_value, 
            dir_folders=dir_folders, 
            collection=collection, 
            scaled_classes=scaled_classes
        )

        if len(inputs) > 0:
            all_inputs.extend(inputs)
            all_targets.extend(targets)

    if not all_inputs:
        return TensorDataset(torch.empty(0), torch.empty(0))

    # Convert to Tensor on the requested device
    input_tensor = torch.tensor(np.array(all_inputs), dtype=dtype, device=device)
    target_tensor = torch.tensor(np.array(all_targets), dtype=dtype, device=device)

    return TensorDataset(input_tensor, target_tensor)

# ==========================================
# 3. MASTER EXECUTION SCRIPT
# ==========================================

# CONFIG
target_month = 4
# the datasets are kept on the CPU to spare GPU memory; a dedicated name is used so that
# the compute `device` selected at the top of the notebook is NOT overwritten
data_device = 'cpu'
dtype = torch.float32
base_dir = os.path.join('data', 'satellite')

train_lists, val_lists, test_lists = [], [], []

print(f"Building Datasets for MONTH {target_month}...\n")

rivers = ['Jamuna', 'Ganges', 'Indus', 'Ghangara']

# Jamuna is the only river split into training/validation/testing folders
jamuna_splits = (('training', train_lists), ('validation', val_lists), ('testing', test_lists))

for river in rivers:
    print(f"River: {river}")

    # the month folder can follow different naming conventions: use the first one that exists
    river_dir = os.path.join(base_dir, f'{river}_images')
    possible_paths = [
        os.path.join(river_dir, f'dataset_month{target_month}'),
        os.path.join(river_dir, f'dataset_month_{target_month}'),
        os.path.join(river_dir, 'preprocessed', f'month_{target_month}'),
        os.path.join(river_dir, 'preprocessed', f'month{target_month}')
    ]
    source_path = next((p for p in possible_paths if os.path.exists(p)), None)

    if source_path is None:
        print("    SKIPPING: Could not find month folder.")
        continue

    print(f"    Reading from: {source_path}")

    # Generate Datasets
    if river == 'Jamuna':
        for split, split_lists in jamuna_splits:
            ds = create_full_dataset(split, dir_folders=source_path, name_filter=split,
                                     device=data_device, dtype=dtype)
            if len(ds) > 0:
                split_lists.append(ds)
    else:
        # Others -> Training
        ds = create_full_dataset('training', dir_folders=source_path, name_filter=None,
                                 device=data_device, dtype=dtype)
        if len(ds) > 0:
            train_lists.append(ds)

# Final Merge (None when no sample is available for a split)
final_train_set = ConcatDataset(train_lists) if train_lists else None
final_val_set = ConcatDataset(val_lists) if val_lists else None
final_test_set = ConcatDataset(test_lists) if test_lists else None

# Summary
print("\n" + "="*40)
print(f"FINAL SUMMARY (Month {target_month})")
print("="*40)
print(f"Training Set:   {len(final_train_set) if final_train_set else 0} samples")
print(f"Validation Set: {len(final_val_set) if final_val_set else 0} samples")
print(f"Testing Set:    {len(final_test_set) if final_test_set else 0} samples")
print("="*40)

Building Datasets for MONTH 4...

River: Jamuna
    Reading from: data\satellite\Jamuna_images\dataset_month4
River: Ganges
    Reading from: data\satellite\Ganges_images\preprocessed\month_4
River: Indus
    Reading from: data\satellite\Indus_images\preprocessed\month_4
River: Ghangara
    Reading from: data\satellite\Ghangara_images\preprocessed\month_4

FINAL SUMMARY (Month 4)
Training Set:   703 samples
Validation Set: 18 samples
Testing Set:    18 samples


# Continue model

**<span style="color:red">Attention!</span>**
\
Uncomment the next cells if larger training, validation, and testing datasets are needed. These cells load the datasets of all four months (January, February, March, and April) and then merge them into one dataset.
\
Keep in mind that due to memory constraints, it is likely that not all four datasets can be loaded.
\
Make sure to load the training, validation, and testing datasets in different cells to reduce memory issues.

In [31]:
# dataset_jan = r'data\satellite\dataset_month1'
# dataset_feb = r'data\satellite\dataset_month2'
# dataset_mar = r'data\satellite\dataset_month3'
# dataset_apr = r'data\satellite\dataset_month4'

# dtype=torch.float32

In [32]:
# inputs_train_jan, targets_train_jan = create_full_dataset(train, dir_folders=dataset_jan, device=device, dtype=dtype).tensors
# inputs_train_feb, targets_train_feb = create_full_dataset(train, dir_folders=dataset_feb, device=device, dtype=dtype).tensors
# inputs_train_mar, targets_train_mar = create_full_dataset(train, dir_folders=dataset_mar, device=device, dtype=dtype).tensors
# inputs_train_apr, targets_train_apr = create_full_dataset(train, dir_folders=dataset_apr, device=device, dtype=dtype).tensors

# inputs_train = torch.cat((inputs_train_jan, inputs_train_feb, inputs_train_mar, inputs_train_apr))
# targets_train = torch.cat((targets_train_jan, targets_train_feb, targets_train_mar, targets_train_apr))
# train_set = TensorDataset(inputs_train, targets_train)

In [33]:
# inputs_val_jan, targets_val_jan = create_full_dataset(val, dir_folders=dataset_jan, device=device, dtype=dtype).tensors
# inputs_val_feb, targets_val_feb = create_full_dataset(val, dir_folders=dataset_feb, device=device, dtype=dtype).tensors
# inputs_val_mar, targets_val_mar = create_full_dataset(val, dir_folders=dataset_mar, device=device, dtype=dtype).tensors
# inputs_val_apr, targets_val_apr = create_full_dataset(val, dir_folders=dataset_apr, device=device, dtype=dtype).tensors

# inputs_val = torch.cat((inputs_val_jan, inputs_val_feb, inputs_val_mar, inputs_val_apr))
# targets_val = torch.cat((targets_val_jan, targets_val_feb, targets_val_mar, targets_val_apr))
# val_set = TensorDataset(inputs_val, targets_val)

In [34]:
# inputs_test_jan, targets_test_jan = create_full_dataset(test, dir_folders=dataset_jan, device=device, dtype=dtype).tensors
# inputs_test_feb, targets_test_feb = create_full_dataset(test, dir_folders=dataset_feb, device=device, dtype=dtype).tensors
# inputs_test_mar, targets_test_mar = create_full_dataset(test, dir_folders=dataset_mar, device=device, dtype=dtype).tensors
# inputs_test_apr, targets_test_apr = create_full_dataset(test, dir_folders=dataset_apr, device=device, dtype=dtype).tensors

# inputs_test = torch.cat((inputs_test_jan, inputs_test_feb, inputs_test_mar, inputs_test_apr))
# targets_test = torch.cat((targets_test_jan, targets_test_feb, targets_test_mar, targets_test_apr))
# test_set = TensorDataset(inputs_test, targets_test)

**<span style="color:red">Attention!</span>**
\
Scaling and normalizing the dataset is not needed, as the pixel values are already in $[0, 1]$.
\
If scaling and normalization are performed anyway, then **the model inputs have to be changed**, as the normalized datasets must be used instead.

In [35]:
# normalize inputs and targets using the training dataset

# scaler_x, scaler_y = scaler(train_set)

# normalized_train_set = normalize_dataset(train_set, scaler_x, scaler_y)
# normalized_val_set = normalize_dataset(val_set, scaler_x, scaler_y)
# normalized_test_set = normalize_dataset(test_set, scaler_x, scaler_y)

In [36]:
# save scalers to be loaded in separate notebooks (i.e., for testing the model)
# should not change unless seed is changed or augmentation increased (randomsplit changes)

# joblib.dump(scaler_x, r'model\scalers\scaler_x.joblib')
# joblib.dump(scaler_y, r'model\scalers\scaler_y.joblib')

In [70]:
# load JamUNet architecture

# NOTE(review): star import placed mid-notebook; UNet3D is expected to come from this module
from model.st_unet.st_unet_3D import *

# size of the first dimension of the first training input
# (requires a non-empty `final_train_set`, i.e. the dataset cell must have found training samples)
n_channels = final_train_set[0][0].shape[0]
n_classes = 1
init_hid_dim = 8
kernel_size = 3
pooling = 'max'

model = UNet3D(n_channels=n_channels, n_classes=n_classes, init_hid_dim=init_hid_dim,
               kernel_size=kernel_size, pooling=pooling, bilinear=False, drop_channels=False)

In [71]:
# print model architecture

model

UNet3D(
  (inc): DoubleConv(
    (net): Sequential(
      (0): Conv3d(1, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
      (1): BatchNorm3d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv3d(8, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
      (4): BatchNorm3d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): Down(
    (net): Sequential(
      (0): MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (net): Sequential(
          (0): Conv3d(8, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
          (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1

In [72]:
# print total number of parameters and model size

num_parameters = sum(p.numel() for p in model.parameters())
print(f"Number of parameters: {num_parameters:.2e}")
# use the actual byte size of each parameter instead of assuming float32 (4 bytes)
model_size_MB = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024 ** 2)
print(f"Model size: {model_size_MB:.2f} MB")

Number of parameters: 1.47e+06
Model size: 5.61 MB


**<span style="color:red">Attention!</span>**
\
Since it is not needed to scale and normalize the dataset (see above), the inputs to the DataLoader are not the normalized datasets.
\
If normalization is performed, the normalized datasets become the inputs to the model.

In [73]:
# hyperparameters
learning_rate = 0.05
batch_size = 16
num_epochs = 1
water_threshold = 0.5
physics = False    # no physics-induced loss terms in the training loss if False
alpha_er = 1e-4    # needed only if physics=True
alpha_dep = 1e-4   # needed only if physics=True

# optimizer to train the model with backpropagation
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

# scheduler for decreasing the learning rate
# every tot epochs (step_size) with given factor (gamma)
step_size = 15     # set to None to remove the scheduler
gamma = 0.75       # set to None to remove the scheduler
# both values must be set; `scheduler` is always defined (None when disabled)
if step_size is not None and gamma is not None:
    scheduler = StepLR(optimizer, step_size = step_size, gamma = gamma)
else:
    scheduler = None

# dataloaders to input data to the model in batches -- see note above if normalization is performed
# a loader is created only for the datasets that contain samples
if final_train_set and len(final_train_set) > 0:
    # drop_last=True: every training batch has exactly `batch_size` samples
    train_loader = DataLoader(final_train_set, batch_size=batch_size, shuffle=True, drop_last=True)
    print("Train Loader Ready!")
if final_val_set and len(final_val_set) > 0:
    val_loader = DataLoader(final_val_set, batch_size=batch_size, shuffle=False, drop_last=False)
    print("Val Loader Ready!")
if final_test_set and len(final_test_set) > 0:
    test_loader = DataLoader(final_test_set, batch_size=batch_size, shuffle=False, drop_last=False)
    print("Test Loader Ready!")

Train Loader Ready!
Val Loader Ready!
Test Loader Ready!


In [74]:
# Test GPU before training
# allocates a tiny tensor on `device`; if that fails for any reason the notebook falls back to the CPU
# NOTE(review): the success message mentions the GPU even when `device` is already 'cpu'
try:
    test_tensor = torch.randn(2, 2).to(device)
    print(f"✓ GPU test successful: {device}")
    del test_tensor
    # release the cached memory of the test allocation
    torch.cuda.empty_cache()
except Exception as e:
    print(f"✗ GPU test failed: {e}")
    print("Switching to CPU...")
    device = 'cpu'

✓ GPU test successful: cuda:0


In [75]:
print(f"Validation samples: {len(final_val_set)}")
print(f"Batches in val_loader: {len(val_loader)}")

Validation samples: 18
Batches in val_loader: 2


In [76]:
# initialize training, validation losses and metrics
train_losses, val_losses = [], []
accuracies, precisions, recalls, f1_scores, csi_scores = [], [], [], [], []
count = 0   # epochs since the last improvement (reset by either criterion below)

# set classification loss - possible options: 'BCE', 'BCE_Logits', and 'Focal'
loss_f = 'BCE'
# set regression loss for physics-induced terms
# possible options: 'Huber', 'RMSE', and 'MAE'
loss_er_dep = 'Huber'

# the scheduler is active only if both of its hyperparameters are set
use_scheduler = step_size is not None and gamma is not None

for epoch in range(1, num_epochs+1):

    # model training
    train_loss = training_unet(model, train_loader, optimizer, water_threshold=water_threshold,
                               device=device, loss_f=loss_f, physics=physics, alpha_er=alpha_er,
                               alpha_dep=alpha_dep, loss_er_dep=loss_er_dep)

    # model validation
    (val_loss, val_accuracy, val_precision,
     val_recall, val_f1_score, val_csi_score) = validation_unet(model, val_loader, device=device, loss_f=loss_f,
                                                                water_threshold=water_threshold)

    # the first epoch sets the reference values
    if epoch == 1:
        best_loss = val_loss
        best_recall = val_recall

    # save model with min val loss
    if val_loss<=best_loss:
        best_model = copy.deepcopy(model)
        best_loss = val_loss
        best_epoch = epoch
        count = 0
    # save model with max recall
    # NOTE(review): `best_epoch` is shared by both criteria, so it holds the epoch of the latest improvement of either
    if val_recall>=best_recall:
        best_model_recall = copy.deepcopy(model)
        best_recall = val_recall
        best_epoch = epoch
        count = 0

    train_losses.append(train_loss)
    val_losses.append(val_loss)

    accuracies.append(val_accuracy)
    precisions.append(val_precision)
    recalls.append(val_recall)
    f1_scores.append(val_f1_score)
    csi_scores.append(val_csi_score)

    count += 1

    if epoch%1 == 0:
        print(f"Epoch: {epoch} | " +
              f"Training loss: {train_loss:.2e}, Validation loss: {val_loss:.2e}, Best validation loss: {best_loss:.2e} " +
              f" | Metrics: Accuracy: {val_accuracy:.3f}, Precision: {val_precision:.3f}, Recall: {val_recall:.3f}," +
              f" F1-score: {val_f1_score:.3f}, CSI-score: {val_csi_score:.3f}, Best recall: {best_recall:.3f}")
        if use_scheduler:
            # learning rate used during this epoch
            print(f'Current learning rate: {scheduler.get_last_lr()[0]}')

    # update the learning rate AFTER the optimizer steps of this epoch:
    # stepping the scheduler first shifts the schedule by one epoch and triggers a PyTorch warning
    if use_scheduler:
        scheduler.step()

OutOfMemoryError: CUDA out of memory. Tried to allocate 62.00 MiB. GPU 0 has a total capacity of 4.00 GiB of which 0 bytes is free. Of the allocated memory 10.79 GiB is allocated by PyTorch, and 98.57 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
metrics = [accuracies, precisions, recalls, f1_scores, csi_scores]

In [None]:
# store training and validation losses and metrics to be stored in a .csv file for later postprocessing
# the dataset month key is taken from `target_month` (the month actually loaded above)
# instead of a hard-coded value that can get out of sync with the data
# NOTE(review): assumes the 6th positional argument is the month key - confirm against save_losses_metrics

save_losses_metrics(train_losses, val_losses, metrics, 'spatial', model, target_month, init_hid_dim,
                    kernel_size, pooling, learning_rate, step_size, gamma, batch_size, num_epochs,
                    water_threshold, physics, alpha_er, alpha_dep, dir_output=r'model\losses_metrics')

**<span style="color:red">Attention!</span>**
\
Always remember to rename the <code>save_path</code> file before running the whole notebook to avoid overwriting it.

In [None]:
# save model with min validation loss
# the dataset month key is taken from `target_month` (the month actually loaded above)
# NOTE(review): assumes the 4th positional argument is the month key - confirm against save_model_path

save_model_path(best_model, 'spatial', 'loss', target_month, init_hid_dim, kernel_size, pooling, learning_rate,
                step_size, gamma, batch_size, num_epochs, water_threshold)

In [None]:
# save model with max recall
# the dataset month key is taken from `target_month` (the month actually loaded above)
# NOTE(review): assumes the 4th positional argument is the month key - confirm against save_model_path

save_model_path(best_model_recall, 'spatial', 'recall', target_month, init_hid_dim, kernel_size, pooling, learning_rate,
                step_size, gamma, batch_size, num_epochs, water_threshold)

In [None]:
# test the min loss model - average loss and metrics
# the same water threshold as in validation is passed explicitly, so that the test
# metrics do not silently depend on the default value of validation_unet

model_loss = copy.deepcopy(best_model)
(test_loss, test_accuracy, test_precision,
 test_recall, test_f1_score, test_csi_score) = validation_unet(model_loss, test_loader, device=device,
                                                               loss_f=loss_f, water_threshold=water_threshold)

print(f'Average metrics for test dataset using model with best validation loss:\n\n\
{loss_f} loss:          {test_loss:.3e}\n\
Accuracy:          {test_accuracy:.3f}\n\
Precision:         {test_precision:.3f}\n\
Recall:            {test_recall:.3f}\n\
F1 score:          {test_f1_score:.3f}\n\
CSI score:         {test_csi_score:.3f}')

In [None]:
# test the max recall model - average loss and metrics
# the same water threshold as in validation is passed explicitly, so that the test
# metrics do not silently depend on the default value of validation_unet

model_recall = copy.deepcopy(best_model_recall)
(test_loss, test_accuracy, test_precision,
 test_recall, test_f1_score, test_csi_score) = validation_unet(model_recall, test_loader, device=device,
                                                               loss_f=loss_f, water_threshold=water_threshold)

print(f'Average metrics for test dataset using model with best validation recall:\n\n\
{loss_f} loss:          {test_loss:.3e}\n\
Accuracy:          {test_accuracy:.3f}\n\
Precision:         {test_precision:.3f}\n\
Recall:            {test_recall:.3f}\n\
F1 score:          {test_f1_score:.3f}\n\
CSI score:         {test_csi_score:.3f}')

In [None]:
plot_losses_metrics(train_losses, val_losses, metrics, model_recall, loss_f=loss_f)

In [None]:
# `test_set` only exists in the commented-out multi-month cells: use the dataset built in this notebook
# NOTE(review): the testing set has 18 samples - check that sample 18 is within range for show_evolution
show_evolution(18, final_test_set, model_loss)

In [None]:
# `test_set` only exists in the commented-out multi-month cells: use the dataset built in this notebook
# NOTE(review): the testing set has 18 samples - check that sample 18 is within range for show_evolution
show_evolution(18, final_test_set, model_recall)

In [None]:
# `val_set` only exists in the commented-out multi-month cells: use the dataset built in this notebook
# NOTE(review): the validation set has 18 samples - check that sample 18 is within range for show_evolution
show_evolution(18, final_val_set, model_recall)

In [None]:
# `test_set` only exists in the commented-out multi-month cells: use the dataset built in this notebook
# NOTE(review): the testing set has 18 samples - check that sample=18 is within range
single_roc_curve(model_loss, final_test_set, sample=18, device=device);

In [None]:
# `test_set` only exists in the commented-out multi-month cells: use the dataset built in this notebook
# NOTE(review): the testing set has 18 samples - check that sample=18 is within range
single_roc_curve(model_recall, final_test_set, sample=18, device=device);

In [None]:
# `test_set` only exists in the commented-out multi-month cells: use the dataset built in this notebook
get_total_roc_curve(model_loss, final_test_set, device=device);

In [None]:
# `test_set` only exists in the commented-out multi-month cells: use the dataset built in this notebook
get_total_roc_curve(model_recall, final_test_set, device=device);

In [None]:
# `test_set` only exists in the commented-out multi-month cells: use the dataset built in this notebook
# NOTE(review): the testing set has 18 samples, so sample=19 is probably out of range - confirm and adjust
single_pr_curve(model_loss, final_test_set, sample=19, device=device)

In [None]:
# `test_set` only exists in the commented-out multi-month cells: use the dataset built in this notebook
# NOTE(review): this cell repeats an earlier show_evolution call with the same arguments
show_evolution(18, final_test_set, model_loss)