# Reproducible Training of nnU-Net

Here is the code to reproduce the training of nnU-Net according to the dataset you want and initialization (not yet available).
This Jupyter Notebook is available on the branch test_leo (``git checkout test_leo``).

Very Important: Before doing anything on this notebook, you should open a service Onyxia entitled "Vscode-pytorch-gpu".

**What is missing?**
- Early stopping (heuristic: 80 epochs which lasts ~5h)
- Different initializations

## 1. Requirements


Python libraries required to run training and handle document downloading / uploading:

In [5]:
!pip install nnunetv2 tqdm s3fs
from tqdm import tqdm
import torch
import s3fs
from pathlib import Path
from tqdm import tqdm
import subprocess
import threading
import time
import os

Collecting argparse (from unittest2->batchgenerators>=0.25.1->nnunetv2)
  Using cached argparse-1.4.0-py2.py3-none-any.whl.metadata (2.8 kB)
Using cached argparse-1.4.0-py2.py3-none-any.whl (23 kB)
Installing collected packages: argparse
Successfully installed argparse-1.4.0


Before training the models, you need to enter your credentials. Example: email --> blabla.blabla@ensae.fr, name --> username Onyxia.

In [None]:
email = input("Enter your email ENSAE: ")
name = input("Enter your username Onyxia: ")

subprocess.run(["git", "config", "--global", "user.email", email])
subprocess.run(["git", "config", "--global", "user.name", name])

print(f"Git configured with email : {email} and username : {name}")

Git configured with email : leo.leroy@ensae.fr and username : leoacpr


Now, you must enter your S3 private keys. They are available on Onyxia > Account > Connexion au stockage. 

In [6]:
aws_access_key_id = input("Enter your AWS_ACCESS_KEY_ID: ")
aws_secret_access_key = input("Enter your AWS_SECRET_ACCESS_KEY: ")
aws_session_token = input("Enter your AWS_SESSION_TOKEN: ")

# Environment variables
os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id
os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key
os.environ["AWS_SESSION_TOKEN"] = aws_session_token

print("AWS keys configured as environment variables.")


AWS keys configured as environment variables.


## 2. Downloading files from S3

The datasets are stored on the S3 service provided by Onyxia. The are available on the path   ``projet-statapp-segmedic/diffusion``. You need to download them locally by running the code below. Estimated time: 4 minutes.

In [7]:
# Connexion to  MinIO S3 Onyxia
s3 = s3fs.S3FileSystem(
    client_kwargs={'endpoint_url': 'https://'+'minio.lab.sspcloud.fr'},
    key=os.getenv("AWS_ACCESS_KEY_ID"),
    secret=os.getenv("AWS_SECRET_ACCESS_KEY"),
    token=os.getenv("AWS_SESSION_TOKEN")
)
#print(len(s3.ls("projet-statapp-segmedic/diffusion/nnunet_dataset/nnUNet_raw/Dataset001_Annot1/labelsTr")))

In [None]:

def download_s3_folder():
    
    # Defining paths
    base_local_path = Path('/tmp/nnunet')
    s3_base_path = "projet-statapp-segmedic/diffusion/nnunet_dataset"
    folders = ['nnUNet_raw', 'nnUNet_preprocessed', 'nnUNet_results']
    
    # Creating local folders
    for folder in folders:
        local_folder = base_local_path / folder
        local_folder.mkdir(parents=True, exist_ok=True)
        
        s3_path = f"{s3_base_path}/{folder}"
        print(f"\nTéléchargement du dossier {folder}...")
        
        # Recursive list of all files from S3
        try:
            files = s3.find(s3_path)
            
            # Progression bar (very nice!)
            with tqdm(total=len(files), desc=f"Fichiers dans {folder}") as pbar:
                for file_path in files:
                    relative_path = file_path.replace(s3_path, '').lstrip('/')
                    local_file_path = local_folder / relative_path
                    
                    # Creating local files if needed
                    local_file_path.parent.mkdir(parents=True, exist_ok=True)
                    
                    # Dowloading files
                    if not local_file_path.exists():
                        try:
                            s3.get(file_path, str(local_file_path))
                        except Exception as e:
                            print(f"Error while downloading {file_path}: {e}")
                    
                    pbar.update(1)
        
        except Exception as e:
            print(f"Error while reading {s3_path}: {e}")
            continue
        
        #ERROR CORRECTED: the nnU-Net dataset naming convention requires 4 digit for image case file, not 3. 
        for string in ['1', '2', '3']:
            images = Path(f"/tmp/nnunet/nnUNet_raw/Dataset00{string}_Annot{string}/imagesTr")
            for f in images.glob("*_000.nii.gz"):
                f.rename(f.with_name(f.name.replace("_000.nii.gz", "_0000.nii.gz")))
    
    # Creating global variables for paths, needed for nnU-Net training. 
    env_vars = {
        'nnUNet_raw': str(base_local_path / 'nnUNet_raw'),
        'nnUNet_preprocessed': str(base_local_path / 'nnUNet_preprocessed'),
        'nnUNet_results': str(base_local_path / 'nnUNet_results')
    }
    
    for var_name, path in env_vars.items():
        os.environ[var_name] = path
    
    # Adding to .bashrc
    with open(os.path.expanduser('~/.bashrc'), 'a') as f:
        f.write('\n# nnUNet paths\n')
        for var_name, path in env_vars.items():
            f.write(f'export {var_name}="{path}"\n')
    
    print("\nConfiguration finished. Environment variables created:")
    for var_name, path in env_vars.items():
        print(f"{var_name}={path}")

    #To apply changes:
    !source ~/.bashrc

download_s3_folder()



Téléchargement du dossier nnUNet_raw...


Fichiers dans nnUNet_raw: 100%|██████████| 133/133 [02:05<00:00,  1.06it/s]



Téléchargement du dossier nnUNet_preprocessed...


Fichiers dans nnUNet_preprocessed: 100%|██████████| 253/253 [01:47<00:00,  2.35it/s]



Téléchargement du dossier nnUNet_results...


Fichiers dans nnUNet_results: 100%|██████████| 45/45 [00:18<00:00,  2.42it/s]



Configuration finished. Environment variables created:
nnUNet_raw=/tmp/nnunet/nnUNet_raw
nnUNet_preprocessed=/tmp/nnunet/nnUNet_preprocessed
nnUNet_results=/tmp/nnunet/nnUNet_results


Verify if the downloading has been done successfully by running the following line. 

Expected output: _dataset_fingerprint.json gt_segmentations nnUNetPlans.json dataset.json nnUNetPlans_3d_fullres splits_final.json_

In [12]:
!ls /tmp/nnunet/nnUNet_preprocessed/Dataset002_Annot2

dataset_fingerprint.json  gt_segmentations	  nnUNetPlans.json
dataset.json		  nnUNetPlans_3d_fullres  splits_final.json


If you wish to preprocess and verify the datasets integrity, copy-paste and run the following lines. **BE CAREFUL:** this might make Onyxia crash if you do not increase the CPU and RAM ressources! It also takes more than 20 min per line.The lines have already been run before. You normally do not need to run them. That is why the lines are not in a code cell.

``!nnUNetv2_plan_and_preprocess -h``

``!nnUNetv2_plan_and_preprocess -d Dataset001_Annot1 -c 3d_fullres --verify_dataset_integrity -np 2 -npfp 2``

``!nnUNetv2_plan_and_preprocess -d Dataset002_Annot2 -c 3d_fullres --verify_dataset_integrity -np 2 -npfp 2``

``!nnUNetv2_plan_and_preprocess -d Dataset003_Annot3 -c 3d_fullres --verify_dataset_integrity -np 2 -npfp 2``


(**Optional**) If you wish to upload all the documents stored locally, you can run the following code. Select one folder among ``nnUNet_preprocessed`` or ``nnUNet_results`` (you normally do not need to upload files from nnUNet_raw). Estimated time: between 10s and 1min10s.

In [4]:
def upload_to_s3(folder):
    from pathlib import Path
    from tqdm import tqdm

    # Dossier local et distant
    local_folder = Path(f'/tmp/nnunet/{folder}')
    s3_folder = f"projet-statapp-segmedic/diffusion/nnunet_dataset/{folder}"
    
    # Lister tous les fichiers à uploader
    files = list(local_folder.rglob("*"))
    
    print(f"\nUploading {folder} to {s3_folder}...")
    with tqdm(total=len(files), desc=f"Upload {folder}") as pbar:
        for file_path in files:
            if file_path.is_file():
                relative_path = file_path.relative_to(local_folder)
                s3_path = f"{s3_folder}/{relative_path.as_posix()}"
                try:
                    s3.put(str(file_path), s3_path)
                except Exception as e:
                    print(f"Erreur lors de l'upload de {file_path} → {s3_path}: {e}")
            pbar.update(1)

upload_to_s3(input("Enter nnUNet_preprocessed or nnUNet_results"))


Uploading  to projet-statapp-segmedic/diffusion/nnunet_dataset/...


Upload : 100%|██████████| 466/466 [00:00<00:00, 8549.47it/s]

Erreur lors de l'upload de /tmp/nnunet/nnUNet_raw/.keep → projet-statapp-segmedic/diffusion/nnunet_dataset//nnUNet_raw/.keep: name 's3' is not defined
Erreur lors de l'upload de /tmp/nnunet/nnUNet_preprocessed/.keep → projet-statapp-segmedic/diffusion/nnunet_dataset//nnUNet_preprocessed/.keep: name 's3' is not defined
Erreur lors de l'upload de /tmp/nnunet/nnUNet_results/Dataset002_Annot2/nnUNetTrainer__nnUNetPlans__3d_fullres/plans.json → projet-statapp-segmedic/diffusion/nnunet_dataset//nnUNet_results/Dataset002_Annot2/nnUNetTrainer__nnUNetPlans__3d_fullres/plans.json: name 's3' is not defined
Erreur lors de l'upload de /tmp/nnunet/nnUNet_results/Dataset002_Annot2/nnUNetTrainer__nnUNetPlans__3d_fullres/dataset.json → projet-statapp-segmedic/diffusion/nnunet_dataset//nnUNet_results/Dataset002_Annot2/nnUNetTrainer__nnUNetPlans__3d_fullres/dataset.json: name 's3' is not defined
Erreur lors de l'upload de /tmp/nnunet/nnUNet_results/Dataset002_Annot2/nnUNetTrainer__nnUNetPlans__3d_fullres




## 3. Training

Training must be jointly done with file uploading: The training creates many documents to save progress. These documents are stored locally, but we need them on S3. Given that epochs take usually about 200s, I decided to set the time interval of uploading to 200s.

Decide on which dataset (i.e. which set of annotations) you want to use: ``Dataset001_Annot1``, ``Dataset002_Annot2``, ``Dataset003_Annot3``. 

**CAREFUL**: The project isn't entirely done. For the moment, there is no early stopping. You should continuoulsy check if Onyxia hasn't crashed during the training (normally it shouldn't happen) and stop about 80 epochs. If you wish to resume training, you can enter this: ``nnUNetv2_train <dataset> 3d_fullres all --npz --c`` but it will only resume from a multiple of 50 epochs (nnU-Net automatically saves its results every 50 epochs). 

In [8]:
# Code to train and upload nnU-Net

dataset=input("Enter one among: Dataset001_Annot1, Dataset002_Annot2, Dataset003_Annot3")

local_results_path = Path(f"/tmp/nnunet/nnUNet_results/f{dataset}")
s3_results_path = f"projet-statapp-segmedic/diffusion/nnunet_dataset/nnUNet_results/{dataset}"

# Upload function with time interval 
# more smartly: upload as soon as the content of temp/results changes
def sync_results_to_s3(interval=300):
    print("[Uploader] Starting S3 sync thread.")
    last_upload_time = time.time()

    while True:
        # Call the upload function
        upload_to_s3(local_results_path)

        # Print "upload done" every 300 seconds
        current_time = time.time()
        if current_time - last_upload_time >= 300:
            print("upload done")
            last_upload_time = current_time

        time.sleep(interval)


# Training function
def run_training():
    print("[Trainer] Launching nnUNet training...")
    command = [
        "nnUNetv2_train",
        f"{dataset}",  # Dataset ID
        "3d_fullres",  # Plan
        "all",  # Fold            
        "--npz",
        "--c"
    ]
    subprocess.run(command)
    print("[Trainer] Training complete.")


# Threads
uploader_thread = threading.Thread(target=sync_results_to_s3, daemon=True)
trainer_thread = threading.Thread(target=run_training)

uploader_thread.start()
trainer_thread.start()

trainer_thread.join()
print("[Main] All done.")

[Uploader] Starting S3 sync thread.

Uploading /tmp/nnunet/nnUNet_results/fDataset001_Annot1 to projet-statapp-segmedic/diffusion/nnunet_dataset//tmp/nnunet/nnUNet_results/fDataset001_Annot1...


Upload /tmp/nnunet/nnUNet_results/fDataset001_Annot1: 0it [00:00, ?it/s]

[Trainer] Launching nnUNet training...





nnUNet_raw is not defined and nnU-Net can only be used on data for which preprocessed files are already present on your system. nnU-Net cannot be used for experiment planning and preprocessing like this. If this is not intended, please read documentation/setting_up_paths.md for information on how to set this up properly.
nnUNet_preprocessed is not defined and nnU-Net can not be used for preprocessing or training. If this is not intended, please read documentation/setting_up_paths.md for information on how to set this up.
nnUNet_results is not defined and nnU-Net cannot be used for training or inference. If this is not intended behavior, please read documentation/setting_up_paths.md for information on how to set this up.

############################
INFO: You are using the old nnU-Net default plans. We have updated our recommendations. Please consider using those instead! Read more here: https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/resenc_presets.md
####################

Traceback (most recent call last):
  File "/usr/local/bin/nnUNetv2_train", line 8, in <module>
    sys.exit(run_training_entry())
             ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/nnunetv2/run/run_training.py", line 267, in run_training_entry
    run_training(args.dataset_name_or_id, args.configuration, args.fold, args.tr, args.p, args.pretrained_weights,
  File "/usr/local/lib/python3.12/site-packages/nnunetv2/run/run_training.py", line 192, in run_training
    nnunet_trainer = get_trainer_from_args(dataset_name_or_id, configuration, fold, trainer_class_name,
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/site-packages/nnunetv2/run/run_training.py", line 61, in get_trainer_from_args
    preprocessed_dataset_folder_base = join(nnUNet_preprocessed, maybe_convert_to_dataset_name(dataset_name_or_id))
                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

[Trainer] Training complete.
[Main] All done.
