# JamUNet: Multi-River Training & Validation

This notebook trains and validates the JamUNet model. So that the model generalizes across braided river morphologies while staying accurate on the target reach, it uses a multi-domain dataset covering the Jamuna, Ganges, Indus, and Ghangara rivers. Before adding a new river, run the following preliminary notebooks for every river so that the data is standardized and the temporal sequences are correctly structured:

- `preliminary\edit_satellite_img.ipynb`: Handles coordinate alignment, remapping pixel classes (Water, Land, No-data), and standardized cropping.

- `preliminary\create_dataset_monthly.ipynb`: Groups images into monthly bins and generates the seasonal averages required for "No-Data" gap filling.

The notebook is built around the st_unet3D model and the training, validation, and test loops, which can be found in:

- `model\st_unet\st_unet_3D.py`

- `model\train_eval.py` 

**Training Strategy & Data Integration**

The dataset is integrated using a "Domain Adaptation" approach to optimize performance specifically for the Jamuna River:

- Spatial Aggregation: Reaches from the Ganges, Indus, and Ghangara rivers are added to the training set to provide a large, diverse set of braiding patterns.

- Targeted Validation: The Jamuna River is partitioned into distinct spatial reaches. Validation and Testing sets consist exclusively of Jamuna data to ensure the performance metrics reflect the model's accuracy on the specific target domain.

- Importance Sampling: A WeightedRandomSampler is utilized in the Training DataLoader. This ensures the model encounters Jamuna training samples more frequently than samples from the auxiliary rivers, biasing the U-Net weights toward Jamuna-specific morphology.
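
To see how strongly the sampler favors the Jamuna, the expected share of each river in a training epoch can be worked out from the per-sample weights. The sketch below is plain Python. The sample counts are hypothetical placeholders (the real counts come from the dataset-building step), `expected_river_share` is an illustrative helper that is not part of the repository, and the 3:1 weighting is the one set later in this notebook.

```python
def expected_river_share(sample_counts, river_importance):
    """Expected fraction of draws per river when every sample of a river
    carries that river's weight and sampling is done with replacement."""
    mass = {r: sample_counts[r] * river_importance[r] for r in sample_counts}
    total = sum(mass.values())
    return {r: m / total for r, m in mass.items()}

# hypothetical sample counts, for illustration only
sample_counts = {'Jamuna': 100, 'Ganges': 100, 'Indus': 100, 'Ghangara': 100}
river_importance = {'Jamuna': 3.0, 'Ganges': 1.0, 'Indus': 1.0, 'Ghangara': 1.0}

shares = expected_river_share(sample_counts, river_importance)
for river, share in shares.items():
    print(f"{river}: {share:.1%}")
```

With equal counts per river, a weight of 3.0 gives the Jamuna half of the draws instead of a quarter.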

**Hardware Management**

- CPU Loading: Datasets are kept in CPU RAM to manage the large memory footprint of multi-river imagery.

- Batch-wise GPU Transfer: To prevent GPU memory errors, data is moved to the GPU device in small, controlled batches during the forward and backward passes.
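
The pattern described above can be sketched as follows. This is a minimal illustration with random stand-in tensors, not the actual loop, which lives in `model\train_eval.py`. The tensor shapes and the names `cpu_dataset` and `batch_devices` are assumptions made for the example.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# hypothetical stand-in data kept in CPU RAM: 8 samples, 4 time steps, 16x16 pixels
inputs = torch.rand(8, 4, 16, 16)
targets = (torch.rand(8, 16, 16) > 0.5).float()
cpu_dataset = TensorDataset(inputs, targets)
loader = DataLoader(cpu_dataset, batch_size=4, shuffle=False)

batch_devices = []
for x, y in loader:
    # only the current batch is copied to the device; the dataset stays on the CPU
    x = x.to(device)
    y = y.to(device)
    batch_devices.append(x.device.type)
    # the forward pass, loss, and backward pass would go here

print(batch_devices, inputs.device)
```

Peak GPU memory then scales with the batch size rather than with the size of the full dataset.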

## Step 1: Set up
Set up the working directory and environment, and check the available GPU and CPU.

In [1]:
%cd c:\Users\mathi\Desktop\TU Delft\TU Delft year 5\Data_science\Morphology_project\jamunet-morpho-braided

c:\Users\mathi\Desktop\TU Delft\TU Delft year 5\Data_science\Morphology_project\jamunet-morpho-braided


In [2]:
# import modules
import os
import torch
import joblib
import copy
import warnings
from torch.utils.data import DataLoader, TensorDataset, ConcatDataset
from torch.optim.lr_scheduler import StepLR
from model.train_eval import *
from preprocessing.dataset_generation_modified import create_full_dataset
from postprocessing.save_results import *
from postprocessing.plot_results import *
from osgeo import gdal

## enable interactive widgets in Jupyter Notebook
%matplotlib inline
%matplotlib widget

### reload modules to avoid restarting the notebook every time these are updated
%load_ext autoreload
%autoreload 2

In [3]:
base_dir = os.path.join('data', 'satellite')

In [4]:
# set the device where operations are performed
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    print("CUDA Device Count: ", torch.cuda.device_count())
    print("CUDA Device Name: ", torch.cuda.get_device_name(0))
else:
    device = 'cpu'

print(f'Using device: {device}')

num_cpus = os.cpu_count()  # total logical cores
print("\nLogical CPU cores:", num_cpus)

CUDA Device Count:  1
CUDA Device Name:  Quadro P2000
Using device: cuda:0

Logical CPU cores: 12


## Step 2: Set up samples

The code below loops over the available rivers and the preprocessed samples in each of them. Since the folder structure of the Jamuna differs slightly from that of the other rivers, the loop checks several candidate paths for each river.
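
The candidate-path lookup can also be factored into a small helper, sketched here under the assumption that the four folder layouts used in the next cell are the only ones in play. `find_month_folder` is a hypothetical name, not a function from the repository, and it checks for a directory rather than any existing path.

```python
import os

def find_month_folder(base_dir, river, month):
    """Return the first existing month folder for a river, or None if none exists."""
    river_dir = os.path.join(base_dir, f'{river}_images')
    candidates = [
        os.path.join(river_dir, f'dataset_month{month}'),
        os.path.join(river_dir, f'dataset_month_{month}'),
        os.path.join(river_dir, 'preprocessed', f'month_{month}'),
        os.path.join(river_dir, 'preprocessed', f'month{month}'),
    ]
    # next() with a default avoids an explicit loop-and-break
    return next((p for p in candidates if os.path.isdir(p)), None)
```

A call such as `find_month_folder(os.path.join('data', 'satellite'), 'Ganges', 3)` returns `None` when no layout matches, so the caller can skip that river.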

In [None]:
# Initiate dataset building
# define parameters

target_month = 3
device = 'cpu' # Keep data on CPU
dtype = torch.float32

train_lists, val_lists, test_lists = [], [], []

print(f"Building Datasets for MONTH {target_month}...\n")

rivers = ['Jamuna', 'Ganges', 'Indus', 'Ghangara']

for river in rivers:
    print(f"River: {river}")
    
    # Intelligent Path Finder
    possible_paths = [
        os.path.join(base_dir, f'{river}_images', f'dataset_month{target_month}'),
        os.path.join(base_dir, f'{river}_images', f'dataset_month_{target_month}'),
        os.path.join(base_dir, f'{river}_images', 'preprocessed', f'month_{target_month}'),
        os.path.join(base_dir, f'{river}_images', 'preprocessed', f'month{target_month}')
    ]
    
    source_path = None
    for p in possible_paths:
        if os.path.exists(p):
            source_path = p
            break
            
    if source_path is None:
        print(f"    SKIPPING: Could not find month folder.")
        continue

    # 1. GENERATE DATASETS
    if river == 'Jamuna':
        # Use specific filter for Jamuna to avoid mixing val/test into training
        ds_tr = create_full_dataset('training', dir_folders=source_path, name_filter='training', device=device, dtype=dtype)
        
        # Pull Val and Test only for Jamuna
        ds_val = create_full_dataset('validation', dir_folders=source_path, name_filter='validation', device=device, dtype=dtype)
        if len(ds_val) > 0: val_lists.append(ds_val)
    
        ds_test = create_full_dataset('testing', dir_folders=source_path, name_filter='testing', device=device, dtype=dtype)
        if len(ds_test) > 0: test_lists.append(ds_test)
    else:
        # For other rivers without 'training' tags, name_filter=None is correct
        ds_tr = create_full_dataset('training', dir_folders=source_path, name_filter=None, device=device, dtype=dtype)

    # 2. SINGLE APPEND (This handles all rivers correctly)
    if ds_tr and len(ds_tr) > 0:
        train_lists.append(ds_tr)
        print(f"    Added {len(ds_tr)} training samples from {river}.")

# Final Merge
final_train_set = ConcatDataset(train_lists) if train_lists else None
final_val_set = ConcatDataset(val_lists) if val_lists else None
final_test_set = ConcatDataset(test_lists) if test_lists else None

# Summary
print("\n" + "="*40)
print(f"FINAL SUMMARY (Month {target_month})")
print("="*40)
print(f"Training Set:   {len(final_train_set) if final_train_set else 0} samples")
print(f"Validation Set: {len(final_val_set) if final_val_set else 0} samples")
print(f"Testing Set:    {len(final_test_set) if final_test_set else 0} samples")
print("="*40)

Building Datasets for MONTH 3...

River: Jamuna
    Reading from: data\satellite\Jamuna_images\dataset_month3




River: Ganges
    Reading from: data\satellite\Ganges_images\preprocessed\month_3
River: Indus
    Reading from: data\satellite\Indus_images\preprocessed\month_3
River: Ghangara
    Reading from: data\satellite\Ghangara_images\preprocessed\month_3

FINAL SUMMARY (Month 3)
Training Set:   791 samples
Validation Set: 14 samples
Testing Set:    18 samples


**<span style="color:red">Attention!</span>**
\
Uncomment the next cells if larger training, validation, and testing datasets are needed. These cells load the datasets for all four months (January, February, March, and April) and merge them into one dataset.
\
Keep in mind that, due to memory constraints, it may not be possible to load all four datasets.
\
Load the training, validation, and testing datasets in separate cells to reduce memory issues.

In [7]:
# dataset_jan = r'data\satellite\dataset_month1'
# dataset_feb = r'data\satellite\dataset_month2'
# dataset_mar = r'data\satellite\dataset_month3'
# dataset_apr = r'data\satellite\dataset_month4'

# dtype=torch.float32

In [8]:
# inputs_train_jan, targets_train_jan = create_full_dataset(train, dir_folders=dataset_jan, device=device, dtype=dtype).tensors
# inputs_train_feb, targets_train_feb = create_full_dataset(train, dir_folders=dataset_feb, device=device, dtype=dtype).tensors
# inputs_train_mar, targets_train_mar = create_full_dataset(train, dir_folders=dataset_mar, device=device, dtype=dtype).tensors
# inputs_train_apr, targets_train_apr = create_full_dataset(train, dir_folders=dataset_apr, device=device, dtype=dtype).tensors

# inputs_train = torch.cat((inputs_train_jan, inputs_train_feb, inputs_train_mar, inputs_train_apr))
# targets_train = torch.cat((targets_train_jan, targets_train_feb, targets_train_mar, targets_train_apr))
# train_set = TensorDataset(inputs_train, targets_train)

In [9]:
# inputs_val_jan, targets_val_jan = create_full_dataset(val, dir_folders=dataset_jan, device=device, dtype=dtype).tensors
# inputs_val_feb, targets_val_feb = create_full_dataset(val, dir_folders=dataset_feb, device=device, dtype=dtype).tensors
# inputs_val_mar, targets_val_mar = create_full_dataset(val, dir_folders=dataset_mar, device=device, dtype=dtype).tensors
# inputs_val_apr, targets_val_apr = create_full_dataset(val, dir_folders=dataset_apr, device=device, dtype=dtype).tensors

# inputs_val = torch.cat((inputs_val_jan, inputs_val_feb, inputs_val_mar, inputs_val_apr))
# targets_val = torch.cat((targets_val_jan, targets_val_feb, targets_val_mar, targets_val_apr))
# val_set = TensorDataset(inputs_val, targets_val)

In [10]:
# inputs_test_jan, targets_test_jan = create_full_dataset(test, dir_folders=dataset_jan, device=device, dtype=dtype).tensors
# inputs_test_feb, targets_test_feb = create_full_dataset(test, dir_folders=dataset_feb, device=device, dtype=dtype).tensors
# inputs_test_mar, targets_test_mar = create_full_dataset(test, dir_folders=dataset_mar, device=device, dtype=dtype).tensors
# inputs_test_apr, targets_test_apr = create_full_dataset(test, dir_folders=dataset_apr, device=device, dtype=dtype).tensors

# inputs_test = torch.cat((inputs_test_jan, inputs_test_feb, inputs_test_mar, inputs_test_apr))
# targets_test = torch.cat((targets_test_jan, targets_test_feb, targets_test_mar, targets_test_apr))
# test_set = TensorDataset(inputs_test, targets_test)

**<span style="color:red">Attention!</span>**
\
Scaling and normalizing the dataset is not needed because the pixel values are already in $[0, 1]$.
\
If scaling and normalization are performed anyway, **the model inputs have to be changed** to the normalized datasets.

In [11]:
# normalize inputs and targets using the training dataset

# scaler_x, scaler_y = scaler(train_set)

# normalized_train_set = normalize_dataset(train_set, scaler_x, scaler_y)
# normalized_val_set = normalize_dataset(val_set, scaler_x, scaler_y)
# normalized_test_set = normalize_dataset(test_set, scaler_x, scaler_y)

In [12]:
# save scalers to be loaded in separate notebooks (i.e., for testing the model)
# should not change unless seed is changed or augmentation increased (randomsplit changes)

# joblib.dump(scaler_x, r'model\scalers\scaler_x.joblib')
# joblib.dump(scaler_y, r'model\scalers\scaler_y.joblib')

## Step 3: Set up model

Load one of the different models to use for training.

In [13]:
# load JamUNet architecture

from model.st_unet.st_unet_3D import *

n_channels = final_train_set[0][0].shape[0]
n_classes = 1
init_hid_dim = 8
kernel_size = 3
pooling = 'max'

model = UNet3D(n_channels=n_channels, n_classes=n_classes, init_hid_dim=init_hid_dim,
               kernel_size=kernel_size, pooling=pooling, bilinear=False, drop_channels=False)

# print model architecture

print(model)

# print total number of parameters and model size

num_parameters = sum(p.numel() for p in model.parameters())
print(f"Number of parameters: {num_parameters:.2e}")
model_size_MB = num_parameters * 4 / (1024 ** 2)  # assuming float32 precision
print(f"Model size: {model_size_MB:.2f} MB")

UNet3D(
  (inc): DoubleConv(
    (net): Sequential(
      (0): Conv3d(1, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
      (1): BatchNorm3d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv3d(8, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
      (4): BatchNorm3d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): Down(
    (net): Sequential(
      (0): MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (net): Sequential(
          (0): Conv3d(8, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
          (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1

**<span style="color:red">Attention!</span>**
\
Since scaling and normalization are not needed (see above), the inputs to the DataLoaders are the original datasets, not the normalized ones.
\
If normalization is performed, the normalized datasets become the inputs to the model.

### Step 3.1: Defining hyperparameters and running Dataloaders

In [16]:
# hyperparameters
learning_rate = 0.05
batch_size = 16
num_epochs = 1
water_threshold = 0.5
physics = False    # no physics-induced loss terms in the training loss if False
alpha_er = 1e-4    # needed only if physics=True
alpha_dep = 1e-4   # needed only if physics=True

# optimizer to train the model with backpropagation
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

# scheduler for decreasing the learning rate
# every tot epochs (step_size) with given factor (gamma)
step_size = 15     # set to None to remove the scheduler
gamma = 0.75       # set to None to remove the scheduler
if step_size is not None and gamma is not None:
    scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)
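
For reference, `StepLR` multiplies the learning rate by `gamma` once every `step_size` scheduler steps, so the rate after a given number of steps has a simple closed form. The sketch below reproduces that formula in plain Python; `step_lr` is an illustrative helper, not a PyTorch function.

```python
def step_lr(initial_lr, step_size, gamma, epoch):
    """Learning rate after `epoch` scheduler steps under a StepLR schedule."""
    return initial_lr * gamma ** (epoch // step_size)

# values from the hyperparameter cell: lr = 0.05, step_size = 15, gamma = 0.75
for epoch in (0, 14, 15, 30):
    print(epoch, step_lr(0.05, 15, 0.75, epoch))
```

With `num_epochs = 1` the first decay at epoch 15 is never reached, so the scheduler only takes effect in longer runs.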

In [17]:
from torch.utils.data import WeightedRandomSampler
batch_size = 16

# 1. Define the importance weights for each river
# We prioritize Jamuna to ensure the model optimizes for our target domain
river_importance = {
    'Jamuna': 3.0,
    'Ganges': 1.0,
    'Indus': 1.0,
    'Ghangara': 1.0
}

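# 2. Identify the river order used when building train_lists
# This must match the order in the 'rivers' loop; if a river was skipped in Step 2,
# the index-based lookup below would assign weights to the wrong river
rivers_in_order = ['Jamuna', 'Ganges', 'Indus', 'Ghangara']
assert len(train_lists) == len(rivers_in_order), "train_lists and rivers_in_order are misaligned"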

sample_weights = []

# 3. Assign the weight to every individual sample in the combined dataset
for i, dataset in enumerate(train_lists):
    river_name = rivers_in_order[i]
    weight = river_importance[river_name]
    
    # Extend the weight list by the number of samples in this river's dataset
    sample_weights.extend([weight] * len(dataset))

# 4. Create the Sampler
# num_samples=len(sample_weights) ensures the epoch length remains the same
sampler = WeightedRandomSampler(
    weights=sample_weights, 
    num_samples=len(sample_weights), 
    replacement=True
)

# 5. Initialize the Training DataLoader
# Note: 'shuffle' must be False when using a sampler
train_loader = DataLoader(
    final_train_set, 
    batch_size=batch_size, 
    sampler=sampler, 
    pin_memory=True,
    drop_last=True
)
# dataloaders to input data to the model in batches -- see note above if normalization is performed
# Create DataLoaders
if final_val_set and len(final_val_set) > 0:
    val_loader = DataLoader(final_val_set, batch_size=batch_size, shuffle=False, drop_last=False)
    print("Val Loader Ready!")
if final_test_set and len(final_test_set) > 0:
    test_loader = DataLoader(final_test_set, batch_size=batch_size, shuffle=False, drop_last=False)
    print("Test Loader Ready!")

print(f"WeightedRandomSampler initialized successfully.")
print(f"Total training samples: {len(sample_weights)}")

Val Loader Ready!
Test Loader Ready!
WeightedRandomSampler initialized successfully.
Total training samples: 791
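
The weights above are matched to datasets by list position, which silently breaks if a river is skipped while building `train_lists`. With the current values this only matters if the Jamuna is missing or the auxiliary weights are made unequal. One possible safeguard, sketched here with plain lists standing in for the per-river datasets, is to carry the river name alongside each dataset. `build_sample_weights` and `named_datasets` are hypothetical names introduced for this example.

```python
def build_sample_weights(named_datasets, river_importance):
    """One weight per sample, looked up by river name instead of list index."""
    weights = []
    for river_name, dataset in named_datasets:
        weights.extend([river_importance[river_name]] * len(dataset))
    return weights

# lists stand in for datasets; Ganges and Ghangara are absent, as if they had been skipped
named_datasets = [('Jamuna', [0, 1]), ('Indus', [0, 1, 2])]
river_importance = {'Jamuna': 3.0, 'Ganges': 1.0, 'Indus': 1.0, 'Ghangara': 1.0}

sample_weights = build_sample_weights(named_datasets, river_importance)
print(sample_weights)
```

The resulting list can be passed to `WeightedRandomSampler` unchanged, provided the datasets are concatenated in the same order.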


In [18]:
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    print("CUDA Device Count: ", torch.cuda.device_count())
    print("CUDA Device Name: ", torch.cuda.get_device_name(0))
else:
    device = 'cpu'

print(f'Using device: {device}')

num_cpus = os.cpu_count()  # total logical cores
print("\nLogical CPU cores:", num_cpus)

CUDA Device Count:  1
CUDA Device Name:  Quadro P2000
Using device: cuda:0

Logical CPU cores: 12


In [19]:
# Test GPU before training
try:
    test_tensor = torch.randn(2, 2).to(device)
    print(f"✓ GPU test successful: {device}")
    del test_tensor
    torch.cuda.empty_cache()
except Exception as e:
    print(f"✗ GPU test failed: {e}")
    print("Switching to CPU...")
    device = 'cpu'

✓ GPU test successful: cuda:0


### Step 3.2: Run training loop

In [None]:
# initialize training, validation losses and metrics
train_losses, val_losses = [], []
accuracies, precisions, recalls, f1_scores, csi_scores = [], [], [], [], []
count = 0

# set classification loss - possible options: 'BCE', 'BCE_Logits', and 'Focal'
loss_f = 'BCE'
# set regression loss for physics-induced terms
# possible options: 'Huber', 'RMSE', and 'MAE'
loss_er_dep = 'Huber'

for epoch in range(1, num_epochs+1):

    # model training
    train_loss = training_unet(model, train_loader, optimizer, water_threshold=water_threshold,
                               device=device, loss_f=loss_f, physics=physics, alpha_er=alpha_er,
                               alpha_dep=alpha_dep, loss_er_dep=loss_er_dep)

    # update the learning rate after this epoch's optimizer steps
    # (PyTorch expects scheduler.step() to be called after optimizer.step())
    if step_size is not None and gamma is not None:
        scheduler.step()

    # model validation
    val_loss, val_accuracy, val_precision, val_recall, val_f1_score, val_csi_score = validation_unet(model, val_loader,
                                                                                                     device=device, loss_f=loss_f,
                                                                                                     water_threshold=water_threshold)

    if epoch == 1:
        best_loss = val_loss
        best_recall = val_recall

    # save model with min val loss
    if val_loss<=best_loss:
        best_model = copy.deepcopy(model)
        best_loss = val_loss
        best_epoch = epoch
        count = 0
    # save model with max recall
    if val_recall>=best_recall:
        best_model_recall = copy.deepcopy(model)
        best_recall = val_recall
        best_epoch = epoch
        count = 0

    train_losses.append(train_loss)
    val_losses.append(val_loss)

    accuracies.append(val_accuracy)
    precisions.append(val_precision)
    recalls.append(val_recall)
    f1_scores.append(val_f1_score)
    csi_scores.append(val_csi_score)

    count += 1

    if epoch%1 == 0:
        print(f"Epoch: {epoch} | " +
              f"Training loss: {train_loss:.2e}, Validation loss: {val_loss:.2e}, Best validation loss: {best_loss:.2e} " +
              f" | Metrics: Accuracy: {val_accuracy:.3f}, Precision: {val_precision:.3f}, Recall: {val_recall:.3f},\
 F1-score: {val_f1_score:.3f}, CSI-score: {val_csi_score:.3f}, Best recall: {best_recall:.3f}")
        if step_size is not None and gamma is not None:
            print(f'Current learning rate: {scheduler.get_last_lr()[0]}')

In [None]:
metrics = [accuracies, precisions, recalls, f1_scores, csi_scores]

### Step 3.3: Save model losses and model paths

In [None]:
save_losses_metrics(
    train_losses, 'Jamuna_Ganges_Ghangara_Indus',  # river label: second positional argument in the function definition
    val_losses, metrics, 'spatial', model, 3, 
    init_hid_dim, kernel_size, pooling, learning_rate, step_size, gamma, batch_size, 
    num_epochs,water_threshold, physics, alpha_er, 
    alpha_dep, dir_output='model/losses_metrics')

**<span style="color:red">Attention!</span>**
\
Always rename the <code>save_path</code> file before running the whole notebook to avoid overwriting it.

In [None]:
# save model with min validation loss
# always check the dataset month key

save_model_path(best_model,'Jamuna_Ganges_Ghangara_Indus', 'spatial', 'loss', 3, init_hid_dim, kernel_size, pooling, learning_rate,
                step_size, gamma, batch_size, num_epochs, water_threshold)

In [None]:
# save model with max recall
# always check the dataset month key

save_model_path(best_model_recall,'Jamuna_Ganges_Ghangara_Indus', 'spatial', 'recall', 3, init_hid_dim, kernel_size, pooling, learning_rate,
                step_size, gamma, batch_size, num_epochs, water_threshold)

## Step 4: Evaluate model performance 

In [None]:
# test the min loss model - average loss and metrics

model_loss = copy.deepcopy(best_model)
test_loss, test_accuracy, test_precision, test_recall, test_f1_score, test_csi_score = validation_unet(
    model_loss, test_loader, device=device, loss_f=loss_f, water_threshold=water_threshold)

print(f'Average metrics for test dataset using model with best validation loss:\n\n\
{loss_f} loss:          {test_loss:.3e}\n\
Accuracy:          {test_accuracy:.3f}\n\
Precision:         {test_precision:.3f}\n\
Recall:            {test_recall:.3f}\n\
F1 score:          {test_f1_score:.3f}\n\
CSI score:         {test_csi_score:.3f}')

In [None]:
# test the max recall model - average loss and metrics

model_recall = copy.deepcopy(best_model_recall)
test_loss, test_accuracy, test_precision, test_recall, test_f1_score, test_csi_score = validation_unet(
    model_recall, test_loader, device=device, loss_f=loss_f, water_threshold=water_threshold)

print(f'Average metrics for test dataset using model with best validation recall:\n\n\
{loss_f} loss:          {test_loss:.3e}\n\
Accuracy:          {test_accuracy:.3f}\n\
Precision:         {test_precision:.3f}\n\
Recall:            {test_recall:.3f}\n\
F1 score:          {test_f1_score:.3f}\n\
CSI score:         {test_csi_score:.3f}')
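
For readers unfamiliar with the reported metrics, the sketch below gives their usual confusion-matrix definitions, with water treated as the positive class. How `validation_unet` actually computes and averages them is defined in `model\train_eval.py` and may differ. The helper name `binary_metrics`, the zero-division handling, and the pixel counts are assumptions made for illustration.

```python
def binary_metrics(tp, fp, fn, tn):
    """Standard confusion-matrix metrics for the positive (water) class."""
    accuracy = (tp + tn) / (tp + fp + fn + tn)
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    # critical success index: hits over hits, false alarms, and misses
    csi = tp / (tp + fp + fn) if (tp + fp + fn) else 0.0
    return {'accuracy': accuracy, 'precision': precision, 'recall': recall,
            'f1': f1, 'csi': csi}

# hypothetical pixel counts, for illustration only
m = binary_metrics(tp=8, fp=2, fn=2, tn=88)
print({name: round(value, 3) for name, value in m.items()})
```

Unlike accuracy, the CSI ignores true negatives, which makes it less sensitive to the large share of land pixels in each image.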

In [None]:
%matplotlib inline

plot_losses_metrics(train_losses, val_losses, metrics, model_recall, loss_f=loss_f)

In [None]:
show_evolution_nolegend_nn(len(final_test_set)-1, final_test_set, model_loss)

In [None]:
show_evolution_nolegend_nn(len(final_test_set)-1, final_test_set, model_recall)

In [None]:
show_evolution_nolegend_nn(len(final_val_set)-1, final_val_set, model_recall)

In [None]:
single_roc_curve(model_loss, final_test_set, sample=len(final_test_set)-1, device=device);

In [None]:
single_roc_curve(model_recall, final_test_set, sample=len(final_test_set)-1, device=device);

In [None]:
get_total_roc_curve(model_loss, final_test_set, device=device);

In [None]:
get_total_roc_curve(model_recall, final_test_set, device=device);

In [None]:
single_pr_curve(model_loss, final_test_set, sample=len(final_test_set)-1, device=device)