In [1]:
import os
import sys
import pathlib

import cv2

script_dir = pathlib.Path('./RDN_segmentation_container/MARS/streamlit_apps/')
sys.path.append(str(script_dir))
import csv
import glob
import h5py
import math
import uuid
import json
import yaml
import torch
import base64
import random
import shutil
import pickle
import socket
import difflib
import platform
import subprocess
import numpy as np
from tqdm.notebook import tqdm
import pandas as pd
import SimpleITK as sitk
import concurrent.futures
import torch.nn as nn
import torch.nn.functional as F
import time
from PIL import Image, ImageColor
from torch import optim
from functools import wraps
from typing import Callable, Tuple
from multiprocessing import cpu_count
from multiprocessing.pool import Pool
from torch.utils.data import DataLoader
from adabelief_pytorch import AdaBelief
from datetime import datetime, timedelta
from timeit import default_timer as timer
from random import shuffle as rand_shuffle
from sklearn.utils import shuffle as sk_shuffle

from torch.utils.tensorboard import SummaryWriter
import torchvision


import utils.dataprocess as dp
from net import UNet_Light_RDN
from utils.generate import *
from utils.label_utils import *
from utils.label_utils import _check_label

from utils.train import rdn_train, rdn_val
from utils.dataset import HDF52D, load_patches, natural_keys
from utils.losses import DomainEnrichLoss, dice_loss, DiceOverlap, Accuracy
from streamlit_label_prep import get_patches, get_dirt_bone_patches, random_patches, glob_flat_list, clean_image, _check_label, rescale_label_proper, check_for_match, rescale_intensity, downscale_intensity, read_train_yaml, initiate_cuda, _convert_size, generate_hdf5, generate_patches, _setup_patches, parallelize_ratios, generate_ratios_streamlit_multi, generate_ratios

In [2]:
# File extensions that are accepted as image slices when scanning the folders.
slice_types = ["tif", "png", "jpg", "jpeg", "bmp", "dcm"]

# Source folders: the labelled images and the unsegmented images.
segmented_dir = pathlib.Path("../../Ressources/data/Training_dataset/Antony/new_labels/")
unsegmented_dir = pathlib.Path("../../Ressources/data/Training_dataset/Antony/new_unseg/")

# Same search for both folders, labelled first, unsegmented second.
segmented_imgs, unsegmented_imgs = [
    glob_flat_list(search_directory=source_dir, file_types=slice_types, unique=True)
    for source_dir in (segmented_dir, unsegmented_dir)
]

# Report how many files were found on each side.
for caption, found_files in (
    ("Nombre d'image labelisé :", segmented_imgs),
    ("Nombre d'image non segmentée :", unsegmented_imgs),
):
    print(caption, len(found_files))

Nombre d'image labelisé : 353
Nombre d'image non segmentée : 353


In [3]:
# Folders that receive the normalised (.tif) copies used for training.
unseg_dir = pathlib.Path("../../Ressources/data/Training_dataset/Antony/new_unseg_training/")
label_dir = pathlib.Path("../../Ressources/data/Training_dataset/Antony/new_labels_training/")


def _prepare_training_dir(source_images, target_dir, created_message, processing_message):
    """Create ``target_dir`` and write a cleaned .tif copy of every source image into it.

    Nothing happens when ``target_dir`` already exists, so re-running the cell
    does not redo the work.

    Args:
        source_images: Paths of the images to pass to ``clean_image``.
        target_dir: ``pathlib.Path`` of the folder to create and fill.
        created_message: Text printed just before the folder is created.
        processing_message: Text printed just before the images are processed.
    """
    # NOTE(review): a folder left behind by an interrupted run is treated as
    # complete; delete it manually to force a new pass.
    if target_dir.exists():
        return

    print(created_message)
    # parents=True so a missing intermediate folder does not abort the cell.
    target_dir.mkdir(parents=True)
    print(processing_message)

    for image_path in tqdm(source_images):
        clean_image(inputFilename=image_path, suffix="", out_name="", out_type="tif", out_dir=target_dir, to_streamlit=False)


_prepare_training_dir(
    unsegmented_imgs,
    unseg_dir,
    "Création du dossier d'entrainement pour les images non segmentées",
    "Traitement des images non-segmentées...",
)
_prepare_training_dir(
    segmented_imgs,
    label_dir,
    "Création du dossier d'entrainement pour les images labelisées",
    "Traitement des images labelisées...",
)


Création du dossier d'entrainement pour les images non segmentées
Traitement des images non-segmentées...


100%|██████████| 353/353 [00:03<00:00, 113.44it/s]


Création du dossier d'entrainement pour les images labelisées
Traitement des images labelisées...


100%|██████████| 353/353 [00:03<00:00, 111.22it/s]


In [4]:
print("Vérification qu'il y a bien 3 classes dans les images labelisées...")

# Walk the label folder lazily, one .tif at a time.
for label_path in label_dir.rglob("*.tif"):
    label_image = sitk.ReadImage(str(label_path))
    has_expected_classes = _check_label(inputImage=label_image, expected_classes=3, to_streamlit=False)

    # Only an explicit False triggers the rescale; any other result is left alone.
    if not (has_expected_classes == False):
        continue

    print(f"Remise à l'échelle des niveaux de gris de {label_path} correspondant aux classes (niveaux de gris = 0, 128, and 255...)")
    rescale_label_proper(input_image=label_image, input_file_name=label_path)

Vérification qu'il y a bien 3 classes dans les images labelisées...


In [5]:
# Full paths of the normalised .tif files in both training folders.
unsegmented_names = glob.glob(str(unseg_dir.joinpath("*.tif")))
print(f"Nombre d'images non segmentées .tif trouvées : {len(unsegmented_names)}")

segmented_names = glob.glob(str(label_dir.joinpath("*.tif")))
print(f"Nombre d'images labelisées .tif trouvées : {len(segmented_names)}")

if len(segmented_names) != len(unsegmented_names):
    print(f"Attention, il n'y a pas le même nombre d'images labelisées ({len(segmented_names)}) et d'images non segmentée ({len(unsegmented_names)}) !!!")
    print("Cela peut être dû au fait que l'étape de normalisation a échoué pour une raison quelconque (types de fichiers non pris en charge) ou les dossiers sources ne contiennent pas les images d'écriture. Veuillez les vérifier et réessayer ")

# Compare on bare file names, since the two sets live in different folders.
unsegmented_file_list = [pathlib.Path(unseg_name).name for unseg_name in unsegmented_names]
label_file_names = [pathlib.Path(label_name).name for label_name in segmented_names]

# Keep only the label file names that have no counterpart among the
# unsegmented images. A set makes each membership test constant time.
known_unsegmented = set(unsegmented_file_list)
match_list = [label_name for label_name in label_file_names if label_name not in known_unsegmented]

Nombre d'images non segmentées .tif trouvées : 353
Nombre d'images labelisées .tif trouvées : 353


In [6]:
# Show the bare file names of the unsegmented training images collected above.
print("Nom des images non segmentées: ", unsegmented_file_list)

Nom des images non segmentées:  ['363_rec0459.tif', '363_rec0459_downscaled.tif', '363_rec0459_rescaled.tif', '366_rec0486.tif', '366_rec0486_downscaled.tif', '366_rec0486_rescaled.tif', 'AF14501414FemurDistu_UnSeg_YM.tif', 'AF14501414FemurDistu_UnSeg_YM_downscaled.tif', 'AF14501414FemurDistu_UnSeg_YM_rescaled.tif', 'AVO74_Hawk_headed_Parrot_x_758.tif', 'AVO74_Hawk_headed_Parrot_x_758_downscaled.tif', 'AVO74_Hawk_headed_Parrot_x_758_rescaled.tif', 'AVO74_Hawk_headed_Parrot_y_780.tif', 'AVO74_Hawk_headed_Parrot_y_780_downscaled.tif', 'AVO74_Hawk_headed_Parrot_y_780_rescaled.tif', 'AVO74_Hawk_headed_Parrot_z_496.tif', 'AVO74_Hawk_headed_Parrot_z_496_downscaled.tif', 'AVO74_Hawk_headed_Parrot_z_496_rescaled.tif', 'BE93TibiaL_UnSeg_ZM.tif', 'BE93TibiaL_UnSeg_ZM_downscaled.tif', 'BE93TibiaL_UnSeg_ZM_rescaled.tif', 'BeliManastirG2Talus_YM.tif', 'BeliManastirG2Talus_YM_downscaled.tif', 'BeliManastirG2Talus_YM_rescaled.tif', 'BeliManastirG2Talus_ZM2.tif', 'BeliManastirG2Talus_ZM2_downscaled.ti

In [7]:
# `match_list` holds the label file names that have no same-named file among
# the unsegmented images, so the caption names them as such.
print("Nom des images labelisées sans image non segmentée correspondante: ", match_list)

Nom des images non labelisées:  ['RADII_NF_LR_DIST_UnSeg.tif', 'VeliaT300Talus_Unseg_XM.tif', 'VeliaT300Talus_Unseg_XM_downscaled.tif', 'VeliaT300Talus_Unseg_XM_rescaled.tif']


In [8]:
print("Vérification du dataset...")

# Candidate unsegmented files for every label that has no exact counterpart.
# The loop is a no-op when `match_list` is empty, so `unmatched_labels` ends up
# defined (as an empty DataFrame) on every path.
unmatched_candidates = {}
for label_name in match_list:
    if unseg_dir.joinpath(label_name).is_file():
        continue

    # Look for the same stem under another extension. Path.stem drops only the
    # last suffix, and glob.escape keeps brackets in names from acting as patterns.
    label_stem = pathlib.Path(label_name).stem
    file_type_check = [str(found) for found in unseg_dir.glob(f"{glob.escape(label_stem)}.*")]

    if file_type_check:
        if len(file_type_check) > 1:
            print(f"Attention : Plusieurs liens détectés pour {label_name}")
            print(file_type_check)
        unmatched_candidates[str(label_name)] = file_type_check
        continue

    # No file with that stem: fall back to the name-similarity helper.
    possible_match = check_for_match(missing_file=label_name, check_filelist=unsegmented_file_list, to_streamlit=False)
    unmatched_candidates[str(label_name)] = ["None"] if possible_match is None else possible_match

if not unmatched_candidates:
    print("Aucun fichier mal nommé n'a été trouvé !")
    unmatched_labels = pd.DataFrame()
else:
    # One row per label; the label name moves from the index to the first column.
    unmatched_labels = pd.DataFrame.from_dict(unmatched_candidates, orient='index').reset_index(drop=False)
    if unmatched_labels.shape[1] > 2:
        candidate_columns = [f"unsegmented_name_match_{column_num}" for column_num in unmatched_labels.columns[1:]]
        unmatched_labels.columns = ["label_name"] + candidate_columns
    else:
        unmatched_labels.columns = ["label_name", "unsegmented_name_match"]


Vérification du dataset...
Aucun fichier mal nommé n'a été trouvé !


In [9]:
print("Data Augmentation : ")
print("Augmentation du nombre de données pour l'entraînement : ")

# File-name endings that mark an augmented copy.
AUGMENTED_SUFFIXES = ("_rescaled.tif", "_downscaled.tif")


def _original_tifs(directory):
    """Return the names of the .tif files under ``directory`` that are not augmented copies."""
    return [
        item.name
        for item in directory.rglob("*.tif")
        if not any(suffix in item.name for suffix in AUGMENTED_SUFFIXES)
    ]


training_list = _original_tifs(unseg_dir)
label_list = _original_tifs(label_dir)

# Defined before the loop so the name exists even when there is nothing to augment.
out_dir = unseg_dir

for image_name in tqdm(training_list):
    image_path = out_dir.joinpath(image_name)
    label_path = label_dir.joinpath(image_name)

    # Build the augmented names from the stem so only the extension is replaced.
    # NOTE(review): assumes rescale_intensity / downscale_intensity write
    # "<stem>_rescaled.tif" / "<stem>_downscaled.tif" into outDir - confirm.
    image_stem = pathlib.Path(image_name).stem
    up_check = f"{image_stem}_rescaled.tif"
    down_check = f"{image_stem}_downscaled.tif"

    rescaled_now = not unseg_dir.joinpath(up_check).exists()
    if rescaled_now:
        rescale_intensity(inputFilename=image_path, writeOut=True, file_type="tif", outDir=out_dir)

    downscaled_now = not unseg_dir.joinpath(down_check).exists()
    if downscaled_now:
        downscale_intensity(inputFilename=image_path, downscale_value=75, writeOut=True, file_type="tif", outDir=out_dir)

    if not label_path.is_file():
        print(f"Attention : aucune image labelisée trouvée pour {image_name}, copie des labels ignorée.")
        continue

    # The augmentations only change intensities, so each augmented image reuses
    # the original label. The copy is made when the image was just generated
    # and also when only the label copy is missing, so every augmented image
    # keeps a matching label.
    for augmented_name, generated_now in ((up_check, rescaled_now), (down_check, downscaled_now)):
        label_copy = label_dir.joinpath(augmented_name)
        if generated_now or not label_copy.exists():
            shutil.copy(str(label_path), str(label_copy))

print("Nouvelles images non-segmentées: ",training_list)
print("Nouvelles images labelisées: ",label_list)

Data Augmentation : 
Augmentation du nombre de données pour l'entraînement : 


  0%|          | 0/136 [00:00<?, ?it/s]

100%|██████████| 136/136 [00:06<00:00, 21.88it/s]

Nouvelles images non-segmentées:  ['363_rec0459.tif', '366_rec0486.tif', 'AF14501414FemurDistu_UnSeg_YM.tif', 'AVO74_Hawk_headed_Parrot_x_758.tif', 'AVO74_Hawk_headed_Parrot_y_780.tif', 'AVO74_Hawk_headed_Parrot_z_496.tif', 'BE93TibiaL_UnSeg_ZM.tif', 'BeliManastirG2Talus_YM.tif', 'BeliManastirG2Talus_ZM2.tif', 'BE_106_Humerus_Prox_R_UnSeg_reoriented_sphereVOI_440.tif', 'BE_111_Humerus_Prox_L_UnSeg_reoriented_sphereVOI_590.tif', 'BE_141A_Femur_Prox_R_UnSeg_reoriented_sphereVOI_432.tif', 'BE_33_Humerus_Prox_R_UnSeg_Y_110.tif', 'BE_37_Humerus_Prox_L_UnSeg_Y_630.tif', 'BE_69_Humerus_Prox_R_UnSeg_X_786.tif', 'BE_84_Femur_Prox_L_UnSeg_Z_113.tif', 'BE_91_Tibia_Dist_R.tif', 'BE_93_Humerus_Prox_R_UnSeg_reoriented_sphereVOI_625.tif', 'BE_99_Humerus_Prox_R_UnSeg_reoriented_sphereVOI_629.tif', 'CMN29181HumerusU_UnSeg_XM.tif', 'Dart_Stw311Femur_Prox.tif', 'Dart_Stw311Femur_Prox3.tif', 'Dart_Stw311Femur_Prox3_10.tif', 'Dart_Stw311Femur_Prox4.tif', 'Dickson_Mounds_278CalcaneusR_UnSeg_01.tif', 'DM_278




In [10]:
# --- Paths ---------------------------------------------------------------
data_path = unseg_dir.parent.parent.joinpath("data")
yaml_dir = script_dir.joinpath("yaml")
train_yaml_path = yaml_dir.joinpath("train.yaml")
test_yaml_path = yaml_dir.joinpath("test.yaml")
train_yaml = str(train_yaml_path)
test_yaml = str(test_yaml_path)
new_models_path = data_path.joinpath("new_model")

# --- Hyper-parameters ------------------------------------------------------
optimizer = "Adam"

# When True (and both YAML files exist) the values are read from the YAML
# files; otherwise the defaults in the else-branch are used.
train_from_yaml = False

if train_from_yaml and train_yaml_path.exists() and test_yaml_path.exists():
    # Read the default YAML files.
    train_yaml_file = read_train_yaml(yaml_file=train_yaml)
    test_yaml_file = read_train_yaml(yaml_file=test_yaml)

    # Take the training parameters from the training YAML.
    batch_size = int(train_yaml_file["data_loader"]["batch_size"])
    period_size = int(train_yaml_file["period"])
    weight_decay = float(train_yaml_file["optimizer"]["weight_decay"])
    epsilon = 10**(-8)
    epochs_num = int(train_yaml_file["train_param"]["Epoch"])
    learning_rate = float(train_yaml_file["optimizer"]["lr"])
else:
    batch_size = 32
    period_size = 8
    weight_decay = 0.1
    epsilon = 10**(-8)
    epochs_num = 60
    learning_rate = 10**(-3)


# --- GPU selection -----------------------------------------------------------
if torch.cuda.is_available() and torch.cuda.device_count() >= 1:
    use_gpu = 0
    print("GPU trouvé !")
    torch.cuda.set_device(use_gpu)
    cuda_mem = int(torch.cuda.get_device_properties(device=use_gpu).total_memory)
    cuda_mem = list(_convert_size(sizeBytes=cuda_mem))
    print(f"GPU réglé sur le numéro de l'appareil {use_gpu}: {torch.cuda.get_device_properties(device=use_gpu).name} ayant {cuda_mem[0]} de VRAM")
else:
    # Keep the name defined so later cells do not fail with a NameError.
    # NOTE(review): confirm that rdn_train / rdn_val accept None as "run on CPU".
    use_gpu = None
    print("Pas de GPU !")

print(f"Training data save path: {data_path}")
print(f"New models will be saved to: {new_models_path}")


# --- Optional warm start from an existing checkpoint --------------------------
train_from_previous = False

if train_from_previous:
    previous_model =  pathlib.Path("Chemin vers le modèle")
    # Map the checkpoint onto the selected GPU, or onto the CPU when there is none.
    load_device = "cpu" if use_gpu is None else f'cuda:{int(use_gpu)}'
    pre_trained = torch.load(previous_model, map_location=load_device)



GPU trouvé !
GPU réglé sur le numéro de l'appareil 0: Quadro RTX 5000 ayant 16.0 GB de VRAM
Training data save path: ..\..\Ressources\data\Training_dataset\data
New models will be saved to: ..\..\Ressources\data\Training_dataset\data\new_model


In [11]:
# print("Save the training parameters in a YAML file")

# if not data_path.exists():
#     data_path.mkdir()

# if train_from_previous:
#     pretrained_path = data_path.joinpath("pretrained_model")
#     if not pretrained_path.exists():
#         pretrained_path.mkdir()
#     model_name = str(pretrained_path.joinpath("pretrained_model.pth").as_posix())
#     torch.save(pre_trained, str(model_name))
#     train_yaml_file["model"]["if_pre_train"] = "true"
#     train_yaml_file["model"]["path"] = str(model_name)
# else:
#     train_yaml_file["model"]["if_pre_train"] = "false"
#     train_yaml_file["model"]["path"] = None

# #GPU device index
# train_yaml_file["gpu_config"]["gpu_name"] = int(use_gpu)

# #Loops
# train_yaml_file["data_loader"]["batch_size"] = int(batch_size)
# if period_size:
#     train_yaml_file["period"] = int(period_size)
# else:
#     train_yaml_file["period"] = None
# train_yaml_file["train_param"]["Epoch"] = int(epochs_num)

# #Optimizer parameters
# train_yaml_file["optimizer"]["method"] = f'{optimizer}'
# train_yaml_file["optimizer"]["lr"] = float(learning_rate)
# if optimizer == "AdaBelief":
#     train_yaml_file["optimizer"]["epsilon"] = epsilon
#     train_yaml_file["optimizer"]["weight_decay"] = "false"
# else:
#     train_yaml_file["optimizer"]["weight_decay"] = float(weight_decay)

# #Data path
# train_yaml_file["path"]["data_path"] = str(data_path.joinpath("dataset.hdf5").as_posix())
# train_yaml_file["path"]["save_path"] = str(new_models_path.as_posix())

# #CSV path
# train_yaml_file["csv_path"]["train"] = str(data_path.joinpath("patches.csv").as_posix())
# train_yaml_file["csv_path"]["val"] = str(data_path.joinpath("val.csv").as_posix())
# train_yaml_file["csv_path"]["ratios"] = str(data_path.joinpath("ratios.csv").as_posix())

# #Test yaml
# test_yaml_file["gpu_config"]["gpu_name"] = int(state.use_gpu)
# test_yaml_file["path"]["data_path"] = str(data_path.joinpath("dataset.hdf5").as_posix())
# test_yaml_file["model"]["path"] = str(new_models_path.as_posix())

# test_yaml_file["csv_path"]["val"] = str(data_path.joinpath("val.csv").as_posix())

# new_yaml_path = data_path.joinpath("yaml")
# new_train_yaml_name = new_yaml_path.joinpath("train.yaml")
# new_test_yaml_name = new_yaml_path.joinpath("test.yaml")
# if not new_yaml_path.exists():
#     new_yaml_path.mkdir()

# with open(str(new_train_yaml_name), 'w') as f:
#     yaml.dump(train_yaml_file, f)
#     new_train_yaml_name = new_train_yaml_name

# with open(str(new_test_yaml_name), 'w') as f:
#     yaml.dump(test_yaml_file, f)
#     new_test_yaml_name = new_test_yaml_name
    
# print(f"Les paramètres d'entraînement et de validation ont été enregistrés : {new_yaml_path}")


In [12]:
print("Finalize data")

# From here on the dataset files (HDF5 and CSVs) are written next to the
# unsegmented training images.
data_path = unseg_dir
if data_path is None or str(data_path) in ("None", "."):
    print("Dataset path not defined")
elif unsegmented_imgs is not None and segmented_imgs is not None:
    training_data_dir = unseg_dir.as_posix()
    training_label_dir = label_dir.as_posix()
    hdf5_name = str(data_path.joinpath("dataset.hdf5").as_posix())
    patches_name = str(data_path.joinpath("patches.csv").as_posix())
    val_name = str(data_path.joinpath("val.csv").as_posix())
    ratios_name = str(data_path.joinpath("ratios.csv").as_posix())

    class_num = 3 #Maybe more or less classes later.

    print(f"Répertoire des données d'entraînement: {training_data_dir}")
    print(f"Répertoire des données labelisées d'entraînement: {training_label_dir}")
    print(f"Fichier HDF5 du dataset: {hdf5_name}")

    # Pack the images and labels into a single HDF5 file.
    generate_hdf5(data_dir=unseg_dir, label_dir=label_dir, save_name=hdf5_name)
    hdf5 = hdf5_name

    # Patch extraction settings and train/validation split.
    stride = 32
    train_size = 0.7
    output = 256
    train_names, val_names = generate_patches(hdf5_file=hdf5_name, patches_csv=patches_name,
                                              validation_csv=val_name, train_ratio=train_size,
                                              stride=stride, output_size=output, always_train_csv=False)

    # Per-patch class ratios, computed in parallel or serially.
    ratios_parallel = True

    if ratios_parallel:
        # Leave one core free, but never ask for fewer than one worker.
        num_threads = max(1, cpu_count() - 1)
        patches_df = _setup_patches(patches_csv=patches_name)
        new_ratios = parallelize_ratios(df=patches_df, func=generate_ratios_streamlit_multi, hdf5_file=hdf5_name, class_num=class_num, n_cores=num_threads, to_streamlit=False)
    else:
        ratios = generate_ratios(hdf5_file=hdf5_name, patches_csv=patches_name, class_num=class_num)
        new_ratios = pd.DataFrame(ratios)
        # Generated headers, in case the number of classes changes.
        new_ratios.columns = [f"Class {idx}" for idx in range(class_num)]

    # Mean ratio of every class over all patches, shown as a one-row table.
    class_means = pd.DataFrame(new_ratios.mean()).T
    class_means.columns = ["Air", "Non-Bone", "Bone"]
    print("Training data class percentages: ",class_means)
    new_ratios.to_csv(ratios_name, index=False)

Finalize data
Répertoire des données d'entraînement: ../../Ressources/data/Training_dataset/Antony/new_unseg_training
Répertoire des données labelisées d'entraînement: ../../Ressources/data/Training_dataset/Antony/new_labels_training
Répertoire des données labelisées d'entraînement: ../../Ressources/data/Training_dataset/Antony/new_unseg_training/dataset.hdf5
Unable to find a match for BE_33_Humerus_Prox_R_UnSeg_Y_110_downscaled. Please check if the cases, underscores, etc. match.
Unable to find a match for BE_37_Humerus_Prox_L_UnSeg_Y_630_downscaled. Please check if the cases, underscores, etc. match.
Unable to find a match for BE_69_Humerus_Prox_R_UnSeg_X_786_downscaled. Please check if the cases, underscores, etc. match.
Unable to find a match for BE_99_Humerus_Prox_R_UnSeg_reoriented_sphereVOI_629_downscaled. Please check if the cases, underscores, etc. match.
Unable to find a match for NF_33_819951_Talus_Whole_R1321_downscaled. Please check if the cases, underscores, etc. match.
U

2023-07-10 16:09:48.684 
  command:

    streamlit run c:\Users\n.vanderesse\AppData\Local\ia-sereos_env\lib\site-packages\ipykernel_launcher.py [ARGUMENTS]


Calcul avec 47 coeur(s).
Training data class percentages:          Air  Non-Bone      Bone
0  0.822593  0.066911  0.110495


In [13]:

# NOTE(review): these settings would sit better in the configuration cell.

n_channel = 1
n_classes = 3
model_path = ""
net = UNet_Light_RDN(n_channels=n_channel, n_classes=n_classes)
if train_from_previous:
    # Load the checkpoint selected in the configuration cell.
    model_path = previous_model
    # Test against None: GPU index 0 is falsy, so a plain truth test would
    # skip the CUDA mapping on the first device.
    if use_gpu is not None:
        load_device = torch.device(type='cuda', index=use_gpu)
    else:
        load_device = torch.device('cpu')
    net.load_state_dict(torch.load(model_path, map_location=load_device))


# The configured optimizer name stays in `optimizer`; the optimizer object gets
# its own name so this cell can be re-run without changing which branch is taken.
if str(optimizer) == "AdaBelief":
    net_optimizer = AdaBelief(net.parameters(),
                              lr=float(learning_rate),
                              eps=float(epsilon),
                              betas=(0.9, 0.999),
                              weight_decouple=True,
                              rectify=False)
else:
    net_optimizer = optim.Adam(net.parameters(),
                               lr=learning_rate,
                               weight_decay=weight_decay)
    # Learning-rate schedule.
    # NOTE(review): the scheduler is created but never stepped (see the
    # commented-out call in the loop), so the learning rate stays constant.
    lr_scheduler = torch.optim.lr_scheduler.StepLR(net_optimizer, step_size=period_size, gamma=0.1)


# Number of training epochs.
Epoch = epochs_num

# Number of epochs spent at each dirt-rate step of the ramp below. Hard-coded
# to 8 (the default Adam period) for now, for AdaBelief as well.
period = 8

# Load patches and ratios.
train_patches = get_patches(hdf5_file=hdf5_name, train_names=train_names, stride=stride, output_size=output)
val_patches = get_patches(hdf5_file=hdf5_name, train_names=val_names, stride=stride, output_size=output)
ratios = new_ratios

# Training transform: augmentation, then mask adjustment, normalisation and
# tensor conversion. Validation uses the same chain without augmentation.
train_transform = transforms.Compose([dp.Augmentation(output_size=output),
                                      dp.AdjustMask(class_num=n_classes),
                                      dp.Normalize(max=255, min=0),
                                      dp.ToTensor()])
val_transform = transforms.Compose([dp.AdjustMask(class_num=n_classes),
                                    dp.Normalize(max=255, min=0),
                                    dp.ToTensor()])


def _dirt_rate_for_epoch(epoch_index, ramp_period):
    """Return the dirt rate for a 0-based epoch: 0.5, 0.3, 0.1, then 0.0, one step per period."""
    if epoch_index < ramp_period:
        return 0.5
    if epoch_index < 2 * ramp_period:
        return 0.3
    if epoch_index < 3 * ramp_period:
        return 0.1
    return 0.0


# Checkpoints go to the folder announced by the configuration cell.
new_models_path.mkdir(parents=True, exist_ok=True)
hdf5_path = data_path.joinpath("dataset.hdf5")
current_batch = int(batch_size)

print("Epoch progress:")
print(f"Progress training {Epoch} epochs...")
total_timer = timer()
for i_epoch in range(Epoch):
    print(f"Epoch {i_epoch + 1} of {Epoch}")
    dirt_rate = _dirt_rate_for_epoch(i_epoch, period)

    # Domain-enrich patches: a fresh selection every epoch, with 0.1 as the
    # dirt threshold passed to random_patches.
    new_patches = random_patches(dirt_choose_threshold=0.1, dirt_rate=dirt_rate,
                                 patches=train_patches, ratios=ratios)

    rdn_patches, index = get_dirt_bone_patches(train_patches, ratios)

    data_set1 = HDF52D(hdf5_path, new_patches, val_patches,
                       train_transform=train_transform,
                       val_transform=val_transform)

    data_set2 = HDF52D(hdf5_path, rdn_patches, val_patches,
                       train_transform=train_transform,
                       val_transform=val_transform,
                       train_idx=index)

    # One loader per dataset, passed together to rdn_train.
    train_data_loader = [
        DataLoader(dataset=epoch_dataset,
                   batch_size=current_batch,
                   shuffle=True,
                   num_workers=0)
        for epoch_dataset in (data_set1, data_set2)
    ]

    print(f"learning rate {net_optimizer.param_groups[0]['lr']:.6f}")

    rdn_train(net, net_optimizer, train_data_loader, epoch=i_epoch,
              total_epoch=Epoch, use_gpu=use_gpu, tensorboard_plot=True)
    #lr_scheduler.step()

    # Validation.
    val_loss, class_val = rdn_val(net, data_set1,
                                  use_gpu=use_gpu,
                                  i_epoch=i_epoch,
                                  class_num=n_classes)

    # Save one checkpoint per epoch, named after the 0-based epoch and its validation loss.
    save_name = new_models_path.joinpath(f"Loss-{i_epoch}_{val_loss:.6f}.pth")
    torch.save(net.state_dict(), save_name)

    # Keep the latest per-class validation table available for inspection.
    class_val = pd.DataFrame(class_val)
    class_val.columns = ["Class Dice overlap"]

print(f"Total training time: {timedelta(seconds=round(timer() - total_timer))}")
print('Training is finished !!')



Epoch progress:
Progress training 60 epochs...
Epoch 1 of 60
There are 101459 bone and 27690 dirt patches in the training data...
learning rate 0.001000


Epoch:1/60:   0%|          | 0/120501 [00:00<?, ? batches/s]

image tensor([[[[0.1490, 0.1451, 0.1529,  ..., 0.1529, 0.1451, 0.1451],
          [0.1569, 0.1490, 0.1529,  ..., 0.1529, 0.1412, 0.1490],
          [0.1608, 0.1490, 0.1451,  ..., 0.1647, 0.1529, 0.1529],
          ...,
          [0.1686, 0.1804, 0.1647,  ..., 0.1451, 0.1412, 0.1608],
          [0.1569, 0.1686, 0.1647,  ..., 0.1451, 0.1490, 0.1529],
          [0.1608, 0.1686, 0.1686,  ..., 0.1529, 0.1569, 0.1608]]],


        [[[0.7569, 0.7098, 0.6745,  ..., 0.8353, 0.8314, 0.8431],
          [0.8000, 0.7608, 0.7137,  ..., 0.8275, 0.8431, 0.8510],
          [0.8314, 0.8039, 0.7686,  ..., 0.8431, 0.8471, 0.8471],
          ...,
          [0.6863, 0.6980, 0.7333,  ..., 0.0000, 0.0000, 0.0000],
          [0.6902, 0.7176, 0.7569,  ..., 0.0000, 0.0000, 0.0000],
          [0.7098, 0.7451, 0.7686,  ..., 0.0000, 0.0000, 0.0000]]],


        [[[0.0039, 0.0039, 0.0078,  ..., 0.0039, 0.0078, 0.0039],
          [0.0000, 0.0039, 0.0039,  ..., 0.0039, 0.0000, 0.0000],
          [0.0039, 0.0000, 0.000

Epoch:1/60:   1%|▏         | 1632/120501 [17:28<21:12:30,  1.56 batches/s, loss=0.72058195, loss1=505.30927, loss2=0.67005104] 


KeyboardInterrupt: 