diff --git a/bash/deploy_win64.sh b/bash/deploy_win64.sh new file mode 100644 index 0000000..60213d1 --- /dev/null +++ b/bash/deploy_win64.sh @@ -0,0 +1,7 @@ + +# create a venv: +python -m venv NAME_OF_YOUR_VIRTUAL_ENV +# then install: +pip install pyinstaller python-tk numpy SimpleITK scikit-image tqdm paramiko pyyaml matplotlib +# run PyInstaller (with console) and exclude unused modules +pyinstaller -F --clean --upx-dir PATH_TO_UPX PATH_TO_GUI/gui.py --exclude-module=biom3d --exclude-module=torch --hidden-import='PIL._tkinter_finder' \ No newline at end of file diff --git a/bash/run_eval.sh b/bash/run_eval.sh index bf0b56b..361130f 100755 --- a/bash/run_eval.sh +++ b/bash/run_eval.sh @@ -11,16 +11,21 @@ # --dir_lab data/msd/Task07_Pancreas/labelsTr_test\ # --num_classes 2 -# python -m biom3d.eval\ -# --dir_pred data/btcv/Testing_small/preds/20230522-182916-unet_default\ -# --dir_lab data/btcv/Testing_small/label\ -# --num_classes 13 - python -m biom3d.eval\ - --dir_pred data/nucleus/official/test/preds/20230908-202124-nucleus_official_fold4\ - --dir_lab data/nucleus/official/test/msk\ + --dir_pred data/msd/Task06_Lung/preds/20230531-092023-unet_lung\ + --dir_lab data/msd/Task06_Lung/labelsTr_test\ --num_classes 1 +# python -m biom3d.eval\ +# --dir_pred data/msd/Task07_Pancreas/preds/20230523-105736-unet_default\ +# --dir_lab data/msd/Task07_Pancreas/labelsTr_test\ +# --num_classes 2 + +# python -m biom3d.eval\ +# --dir_pred data/nucleus/official/test/preds/20230908-202124-nucleus_official_fold4\ +# --dir_lab data/nucleus/official/test/msk\ +# --num_classes 1 + # python -m biom3d.eval\ # --dir_pred data/mito/test/pred/20230203-091249-unet_mito\ # --dir_lab data/mito/test/msk\ diff --git a/bash/run_pred.sh b/bash/run_pred.sh index c1b4f6e..ed96211 100755 --- a/bash/run_pred.sh +++ b/bash/run_pred.sh @@ -112,18 +112,25 @@ # --dir_in "data/nucleus/aline_bug/img"\ # --dir_out "data/nucleus/aline_bug/preds" +# python -m biom3d.pred\ +# --name seg\ +# --log logs/20240219-100225-reims_full_fold0\ +# --dir_in data/reims/big_stack/img\ +# --dir_out data/reims/big_stack/preds\ + python -m biom3d.pred\ - --name seg\ - --log logs/20240219-100225-reims_full_fold0\ - --dir_in data/reims/big_stack/img\ - --dir_out data/reims/big_stack/preds\ + --name seg_eval\ + --log logs/20240319-094430-reims_large_full_no47_fold0\ + --dir_in data/reims/large/test/img\ + --dir_out data/reims/large/test/preds\ + --dir_lab data/reims/large/test/msk # python -m biom3d.pred\ # --name seg_eval\ -# --log logs/20240218-072550-reims_fold0\ -# --dir_in data/reims/test_match/img\ -# --dir_out data/reims/test_match/preds\ -# --dir_lab data/reims/test_match/msk +# --log logs/20230501-153638-unet_default\ +# --dir_in data/mito/test/img\ +# --dir_out data/mito/test/pred\ +# --dir_lab data/mito/test/msk # python -m biom3d.pred\ # --name seg_single\ diff --git a/docs/_static/image/gui_remote_predict.png b/docs/_static/image/gui_remote_predict.png new file mode 100644 index 0000000..2c599fb Binary files /dev/null and b/docs/_static/image/gui_remote_predict.png differ diff --git a/docs/_static/image/gui_remote_predict_omero.png b/docs/_static/image/gui_remote_predict_omero.png new file mode 100644 index 0000000..38dd990 Binary files /dev/null and b/docs/_static/image/gui_remote_predict_omero.png differ diff --git a/docs/_static/image/gui_remote_train.PNG b/docs/_static/image/gui_remote_train.PNG deleted file mode 100644 index 58a6aed..0000000 Binary files a/docs/_static/image/gui_remote_train.PNG and /dev/null differ diff --git
a/docs/_static/image/gui_remote_train.png b/docs/_static/image/gui_remote_train.png new file mode 100644 index 0000000..2c82de4 Binary files /dev/null and b/docs/_static/image/gui_remote_train.png differ diff --git a/docs/_static/image/gui_splash.png b/docs/_static/image/gui_splash.png new file mode 100644 index 0000000..4683457 Binary files /dev/null and b/docs/_static/image/gui_splash.png differ diff --git a/docs/biom3d_colab.ipynb b/docs/biom3d_colab.ipynb index 430fe76..05cbe7a 100644 --- a/docs/biom3d_colab.ipynb +++ b/docs/biom3d_colab.ipynb @@ -1 +1 @@ -{"cells":[{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"P8jnH3QZH9L4"},"outputs":[],"source":["#@markdown ## First check that your Runtime is in GPU mode (you can run this cell to do so)\n","\n","#@markdown If not, go to `Runtime` > `Change runtime type` > `Hardware accelerator` > `GPU`\n","\n","!nvidia-smi"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"NhwL5qt9fdg5"},"outputs":[],"source":["#@markdown ##Install Biom3d and import the necessary Python library\n","!pip3 install biom3d==0.0.31\n","\n","import os\n","import biom3d\n","from biom3d.train import train\n","from biom3d.pred import pred\n","from biom3d.eval import eval"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"Relk7WNAkTzM"},"outputs":[],"source":["#@markdown ##Mount your Google drive\n","\n","from google.colab import drive\n","drive.mount('/content/gdrive')"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"3RPq3yPxlNmj"},"outputs":[],"source":["#@markdown ## Define input and output folders\n","\n","#@markdown **Where training data will be saved?** This will include two folders: the config folder (containing configuration files) and the logs folder (containing the model folders)\n","\n","biom3d_dir = '/content/gdrive/MyDrive/biom3d'#@param {type:\"string\"}\n","model_name = \"unet_default\" #@param {type:\"string\"}\n","\n","#@markdown **Where are your images and masks?**\n","img_dir = \"\" #@param {type:\"string\"}\n","msk_dir = \"\" #@param {type:\"string\"}\n","\n","#@markdown **How many classes are there in your annotations?**\n","\n","num_classes = 1 #@param {type:\"number\"}\n","\n","#@markdown **How long would you like to train your model?**\n","num_epochs = 2 #@param {type:\"number\"}"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"EkLs_MoFtnLv"},"outputs":[],"source":["#@markdown ## Preprocess your data before training\n","config_path = biom3d.preprocess.auto_config_preprocess(\n"," img_dir=img_dir,\n"," msk_dir=msk_dir,\n"," num_classes=num_classes,\n"," desc=model_name,\n"," config_dir=os.path.join(biom3d_dir, \"configs/\"),\n"," base_config=None,\n"," ct_norm=False,\n"," num_epochs=num_epochs,\n"," logs_dir=os.path.join(biom3d_dir, \"logs/\"),\n",")"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"fn639Efzrfwe"},"outputs":[],"source":["#@markdown ## Start the training\n","\n","#@markdown If you want to use of different configuration file or use an existing one, complete the following field. 
**Leave this field empty if you want to use the config file obtained during the preprocessing.**\n","\n","custom_config_path = \"\" #@param {type:\"string\"}\n","\n","if len(custom_config_path) > 0:\n"," config_path = custom_config_path\n","\n","print(\"We will the following configuration file for training:\", config_path)\n","builder = train(config=config_path)"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"4eKiLPQnsokH"},"outputs":[],"source":["#@markdown ## Make predictions with your trained model\n","\n","#@markdown **Path to the model folder: leave it empty if you want to use the model obtained during the previous training step.** This should look like `/content/gdrive/MyDrive/biom3d/logs/20230602-162331-unet_default_fold0`.\n","\n","custom_log_path = \"\" #@param {type:\"string\"}\n","\n","if len(custom_log_path) > 0:\n"," log=custom_log_path\n","else:\n"," assert 'builder' in locals().keys(), \"No existing model folder found. Please complete the `custom_log_path` field or train a model.\"\n"," log = builder.base_dir\n","\n","#@markdown **Prediction input directory:**\n","pred_dir_in = \"\" #@param {type:\"string\"}\n","\n","#@markdown **Prediction output directory:** (where the prediction masks will be stored)\n","pred_dir_out = \"\"#@param {type:\"string\"}\n","\n","dir_out = pred(\n"," log=log,\n"," dir_in=pred_dir_in,\n"," dir_out=pred_dir_out\n"," )\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"8l8FPDb_13iI"},"outputs":[],"source":["#@markdown ## Evaluate your model on a test set\n","\n","#@markdown **Prediction output directory:** where the previous prediction masks have been stored and the number of classes in your images. Leave the default values if you want to use the previous prediction path.\n","new_pred_dir_out = \"\"#@param {type:\"string\"}\n","new_num_classes = 0 #@param {type:\"number\"}\n","\n","if new_num_classes == 0:\n"," assert 'num_classes' in locals().keys(), \"Number of classes equal to zero and the previous number of classes does not exist. 
Please provide one.\"\n"," new_num_classes = num_classes\n","\n","if len(new_pred_dir_out) == 0:\n"," assert len(dir_out) > 0, \"Prediction path seems to be empty and no previous path detected.\"\n"," new_pred_dir_out = dir_out\n","\n","#@markdown **Path to test masks** The test masks must correspond to the predictions.\n","test_msk_dir = \"\" #@param {type:\"string\"}\n","\n","eval(\n"," dir_lab=test_msk_dir,\n"," dir_out=new_pred_dir_out,\n"," num_classes=new_num_classes\n",")"]}],"metadata":{"accelerator":"GPU","colab":{"authorship_tag":"ABX9TyPrfgBkQIScxLMN00pfvqBQ","gpuType":"T4","provenance":[]},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python"}},"nbformat":4,"nbformat_minor":0} +{"cells":[{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"P8jnH3QZH9L4"},"outputs":[],"source":["#@markdown ## First check that your Runtime is in GPU mode (you can run this cell to do so)\n","\n","#@markdown If not, go to `Runtime` > `Change runtime type` > `Hardware accelerator` > `GPU`\n","\n","!nvidia-smi"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"NhwL5qt9fdg5"},"outputs":[],"source":["#@markdown ##Install Biom3d and import the necessary Python libraries\n","!pip3 install biom3d==0.0.40 torchio deprecated --no-deps\n","!pip3 install SimpleITK paramiko netcat appdirs\n","\n","import os\n","import biom3d\n","from biom3d.train import train\n","from biom3d.pred import pred\n","from biom3d.eval import eval"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"Relk7WNAkTzM"},"outputs":[],"source":["#@markdown ##Mount your Google drive\n","\n","from google.colab import drive\n","drive.mount('/content/gdrive')"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"3RPq3yPxlNmj"},"outputs":[],"source":["#@markdown ## Define input and output folders\n","\n","#@markdown **Where will the training data be saved?** This will include two folders: the config folder (containing configuration files) and the logs folder (containing the model folders)\n","\n","biom3d_dir = '/content/gdrive/MyDrive/biom3d'#@param {type:\"string\"}\n","model_name = \"unet_default\" #@param {type:\"string\"}\n","\n","#@markdown **Where are your images and masks?**\n","img_dir = \"\" #@param {type:\"string\"}\n","msk_dir = \"\" #@param {type:\"string\"}\n","\n","#@markdown **How many classes are there in your annotations?**\n","\n","num_classes = 1 #@param {type:\"number\"}\n","\n","#@markdown **How long would you like to train your model?**\n","num_epochs = 2 #@param {type:\"number\"}"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"EkLs_MoFtnLv"},"outputs":[],"source":["#@markdown ## Preprocess your data before training\n","config_path = biom3d.preprocess.auto_config_preprocess(\n"," img_dir=img_dir,\n"," msk_dir=msk_dir,\n"," num_classes=num_classes,\n"," desc=model_name,\n"," config_dir=os.path.join(biom3d_dir, \"configs/\"),\n"," base_config=None,\n"," ct_norm=False,\n"," num_epochs=num_epochs,\n"," logs_dir=os.path.join(biom3d_dir, \"logs/\"),\n",")"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"fn639Efzrfwe"},"outputs":[],"source":["#@markdown ## Start the training\n","\n","#@markdown If you want to use a different configuration file or an existing one, complete the following field. 
**Leave this field empty if you want to use the config file obtained during the preprocessing.**\n","\n","custom_config_path = \"\" #@param {type:\"string\"}\n","\n","if len(custom_config_path) > 0:\n"," config_path = custom_config_path\n","\n","print(\"We will use the following configuration file for training:\", config_path)\n","builder = train(config=config_path)"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"4eKiLPQnsokH"},"outputs":[],"source":["#@markdown ## Make predictions with your trained model\n","\n","#@markdown **Path to the model folder: leave it empty if you want to use the model obtained during the previous training step.** This should look like `/content/gdrive/MyDrive/biom3d/logs/20230602-162331-unet_default_fold0`.\n","\n","custom_log_path = \"\" #@param {type:\"string\"}\n","\n","if len(custom_log_path) > 0:\n"," log=custom_log_path\n","else:\n"," assert 'builder' in locals().keys(), \"No existing model folder found. Please complete the `custom_log_path` field or train a model.\"\n"," log = builder.base_dir\n","\n","#@markdown **Prediction input directory:**\n","pred_dir_in = \"\" #@param {type:\"string\"}\n","\n","#@markdown **Prediction output directory:** (where the prediction masks will be stored)\n","pred_dir_out = \"\"#@param {type:\"string\"}\n","\n","dir_out = pred(\n"," log=log,\n"," dir_in=pred_dir_in,\n"," dir_out=pred_dir_out\n"," )\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"8l8FPDb_13iI"},"outputs":[],"source":["#@markdown ## Evaluate your model on a test set\n","\n","#@markdown **Prediction output directory:** where the previous prediction masks have been stored and the number of classes in your images. Leave the default values if you want to use the previous prediction path.\n","new_pred_dir_out = \"\"#@param {type:\"string\"}\n","new_num_classes = 0 #@param {type:\"number\"}\n","\n","if new_num_classes == 0:\n"," assert 'num_classes' in locals().keys(), \"Number of classes equal to zero and the previous number of classes does not exist. Please provide one.\"\n"," new_num_classes = num_classes\n","\n","if len(new_pred_dir_out) == 0:\n"," assert len(dir_out) > 0, \"Prediction path seems to be empty and no previous path detected.\"\n"," new_pred_dir_out = dir_out\n","\n","#@markdown **Path to test masks** The test masks must correspond to the predictions.\n","test_msk_dir = \"\" #@param {type:\"string\"}\n","\n","eval(\n"," dir_lab=test_msk_dir,\n"," dir_out=new_pred_dir_out,\n"," num_classes=new_num_classes\n",")"]}],"metadata":{"accelerator":"GPU","colab":{"authorship_tag":"ABX9TyPrfgBkQIScxLMN00pfvqBQ","gpuType":"T4","provenance":[]},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python"}},"nbformat":4,"nbformat_minor":0} diff --git a/docs/quick_run_gui.md b/docs/quick_run_gui.md index 8c9ee08..939b30a 100644 --- a/docs/quick_run_gui.md +++ b/docs/quick_run_gui.md @@ -18,7 +18,7 @@ Biom3d comes with 2 modes: local or remote. 'Local' means that the computation w If you have installed biom3d with the local version simply click on the 'Start locally' button to start, you can choose a path to store your files in the field over the button, by default, the files are stored in the directory where biom3d have been launched. -If you have installed biom3d with the remote version, you must then complete the required fields. The first one is the IP address of your remote computer (where the API of biom3d is installed). The second and third one is your user name and password to connect to the remote computer, the forth one is the path to your virtual environment (if you don't have a virtual environment leave it empty). +If you have installed biom3d with the remote version, you must then complete the required fields. The first one is the IP address of your remote computer (where the API of biom3d is installed). The second and third ones are your user name and password to connect to the remote computer; the fourth one is the path to your virtual environment (if you don't have a virtual environment, leave it empty). ## Preprocess & Train diff --git a/logo_biom3d_crop.ico b/logo_biom3d_crop.ico new file mode 100644 index 0000000..c6291ab Binary files /dev/null and b/logo_biom3d_crop.ico differ diff --git a/pyproject.toml b/pyproject.toml index e0281be..f5db59a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta" [project] name = "biom3d" -version = "0.0.37" +version = "0.0.40" authors = [ {name="Guillaume Mougeot", email="guillaume.mougeot@laposte.net"}, ] @@ -17,7 +17,7 @@ classifiers = [ "Operating System :: OS Independent", ] dependencies = [ - "torch>1.10.0", + "torch", "tqdm>=4.62.3", "scikit-image>=0.14", "scipy>=1.9.1", diff --git a/src/biom3d/auto_config.py b/src/biom3d/auto_config.py index 44e5fd2..61daf36 100644 --- a/src/biom3d/auto_config.py +++ b/src/biom3d/auto_config.py @@ -38,6 +38,10 @@ def compute_median(path, return_spacing=False): for i in range(len(path_imgs)): img,metadata = adaptive_imread(path_imgs[i]) + # Check if the image is 2D (has two dimensions) + if len(img.shape) == 2: + # Add a third dimension with size 1 to make it 3D + img = np.expand_dims(img, axis=0) spacing = None if not 'spacing' in metadata.keys() else metadata['spacing'] assert len(img.shape)>0, "[Error] Wrong image image." @@ -91,6 +95,10 @@ def data_fingerprint(img_dir, msk_dir=None, num_samples=10000): for i in range(len(path_imgs)): img,metadata = adaptive_imread(path_imgs[i]) + # Check if the image is 2D (has two dimensions) + if len(img.shape) == 2: + # Add a third dimension with size 1 to make it 3D + img = np.expand_dims(img, axis=0) spacing = None if not 'spacing' in metadata.keys() else metadata['spacing'] # store the size @@ -103,6 +111,10 @@ if msk_dir is not None: # read msk msk,_ = adaptive_imread(path_msks[i]) + # Check if the image is 2D (has two dimensions) + if len(msk.shape) == 2: + # Add a third dimension with size 1 to make it 3D + msk = np.expand_dims(msk, axis=0) # extract only useful voxels img = img[msk > 0] @@ -151,7 +163,7 @@ def find_patch_pool_batch(dims, max_dims=(128,128,128), max_pool=5, epsilon=1e-3 batch: numpy.ndarray Batch size. """ - # transform tuples into arrays + # transform tuples into arrays assert len(dims)==3 or len(dims)==4, print("Dims has not the correct number of dimensions: len(dims)=", len(dims)) if len(dims)==4: dims=dims[1:] @@ -279,15 +291,25 @@ def get_aug_patch(patch_size): # ---------------------------------------------------------------------------- # Display +def parameters_return(patch, pool, batch, config_path, median_spacing=None): + """ + Print the computed parameters, one per line, so that a remote caller can parse them. + median_spacing is accepted to match the call site below; it is already stored in the config file. + """ + print(batch) + print(patch) + print(get_aug_patch(patch)) + print(pool) + print(config_path) def display_info(patch, pool, batch): """Print in terminal the patch size, the number of pooling and the batch size. 
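+ + Parameters + ---------- + patch : array_like + Patch size (printed as PATCH_SIZE). + pool : array_like + Number of pooling operations per axis (printed as NUM_POOLS). + batch : int + Batch size (printed as BATCH_SIZE). 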
""" + print("*"*20,"YOU CAN COPY AND PASTE THE FOLLOWING LINES INSIDE THE CONFIG FILE", "*"*20) print("BATCH_SIZE =", batch) print("PATCH_SIZE =", list(patch)) aug_patch = get_aug_patch(patch) - print("AUG_PATCH_SIZE =",list(aug_patch)) + print("AUG_PATCH_SIZE =",list(aug_patch)) print("NUM_POOLS =", list(pool)) def auto_config(img_dir=None, median=None, max_dims=(128,128,128), max_batch=16, min_batch=2): @@ -342,6 +364,8 @@ def auto_config(img_dir=None, median=None, max_dims=(128,128,128), max_batch=16, help="(default=\'configs/\') Configuration folder to save the auto-configuration.") parser.add_argument("--base_config", type=str, default=None, help="(default=None) Optional. Path to an existing configuration file which will be updated with the preprocessed values.") + parser.add_argument("--remote", default=False, dest='remote', + help="Use this arg when using remote autoconfing only") args = parser.parse_args() median = compute_median(path=args.img_dir, return_spacing=args.spacing) @@ -349,30 +373,33 @@ def auto_config(img_dir=None, median=None, max_dims=(128,128,128), max_batch=16, if args.spacing: median_spacing = median[1] median = median[0] + else: + median_spacing = None patch, pool, batch = find_patch_pool_batch(dims=median, max_dims=(args.max_dim, args.max_dim, args.max_dim)) aug_patch = np.array(patch)+2**(np.array(pool)+1) - display_info(patch, pool, batch) - - if args.spacing:print("MEDIAN_SPACING =",list(median_spacing)) - if args.median:print("MEDIAN =", list(median)) - - if args.save_config: + + if args.remote or args.save_config: try: from biom3d.utils import save_python_config config_path = save_python_config( config_dir=args.config_dir, base_config=args.base_config, - BATCH_SIZE=batch, AUG_PATCH_SIZE=aug_patch, PATCH_SIZE=patch, NUM_POOLS=pool, MEDIAN_SPACING=median_spacing, ) + parameters_return(patch, pool, batch, config_path, median_spacing) except: - print("[Error] Import error. Biom3d must be installed if you want to save your configuration. Another solution is to config the function function in biom3d.utils here...") + print("[Error] Import error. Biom3d must be installed if you want to save your configuration. 
Another solution is to config the function in biom3d.utils here...") raise ImportError + else : + display_info(patch, pool, batch) + if args.spacing:print("MEDIAN_SPACING =",list(median_spacing)) + if args.median:print("MEDIAN =", list(median)) + -# ---------------------------------------------------------------------------- \ No newline at end of file +# ---------------------------------------------------------------------------- diff --git a/src/biom3d/builder.py b/src/biom3d/builder.py index 7ef4441..a657513 100644 --- a/src/biom3d/builder.py +++ b/src/biom3d/builder.py @@ -695,7 +695,7 @@ def run_prediction_single(self, img_path=None, img=None, img_meta=None, return_l **img_meta) # all img_meta should be equal as we use the same preprocessors else: # single model prediction - img, img_meta = read_config(self.config.PREPROCESSOR, register.preprocessors, img=img, img_meta=img_meta) + img, img_meta = read_config(self.config.PREPROCESSOR, register.preprocessors, img=img, img_meta=img_meta, is_2d= self.config.IS_2D) print("Preprocessed shape:", img.shape) diff --git a/src/biom3d/config_default.py b/src/biom3d/config_default.py index 00edbf9..ee7bdbb 100644 --- a/src/biom3d/config_default.py +++ b/src/biom3d/config_default.py @@ -64,7 +64,7 @@ def __delattr__(self, name): del self[name] #--------------------------------------------------------------------------- # Auto-config builder-parameters # PASTE AUTO-CONFIG RESULTS HERE - +IS_2D = False # batch size BATCH_SIZE = 2 diff --git a/src/biom3d/datasets/semseg_batchgen.py b/src/biom3d/datasets/semseg_batchgen.py index dc19e53..79eeaf0 100644 --- a/src/biom3d/datasets/semseg_batchgen.py +++ b/src/biom3d/datasets/semseg_batchgen.py @@ -76,6 +76,27 @@ def centered_crop(img, msk, center, crop_shape, margin=np.zeros(3)): def located_crop(img, msk, location, crop_shape, margin=np.zeros(3)): """Do a crop, forcing the location voxel to be located in the crop. + + Parameters + ---------- + img : ndarray + Image data. + msk : ndarray + Mask data. + location : array_like + Specific voxel location to include in the crop. + crop_shape : array_like + Shape of the crop. + margin : array_like, optional + Margin around the location. + + Returns + ------- + crop_img : ndarray + Cropped image data, containing the specified location voxel within the crop. + crop_msk : ndarray + Cropped mask data, corresponding to the cropped image region. + """ img_shape = np.array(img.shape)[1:] location = np.array(location) @@ -96,6 +117,29 @@ def located_crop(img, msk, location, crop_shape, margin=np.zeros(3)): def foreground_crop(img, msk, final_size, fg_margin, fg=None, use_softmax=True): """Do a foreground crop. + + Parameters + ---------- + img : ndarray + Image data. + msk : ndarray + Mask data. + final_size : array_like + Final size of the cropped image and mask. + fg_margin : array_like + Margin around the foreground location. + fg : dict, optional + Foreground information. + use_softmax : bool, optional + If True, assumes softmax activation. + + Returns + ------- + img : ndarray + Cropped image data, focused on the foreground region. + msk : ndarray + Cropped mask data, corresponding to the cropped image region. + """ if fg is not None: locations = fg[random.choice(list(fg.keys()))] @@ -119,6 +163,23 @@ def foreground_crop(img, msk, final_size, fg_margin, fg=None, use_softmax=True): def random_crop(img, msk, crop_shape): """ randomly crop a portion of size prop of the original image size. + + Parameters + ---------- + img : ndarray + Image data. 
+ msk : ndarray + Mask data. + crop_shape : array_like + Shape of the crop. + + Returns + ------- + crop_img : ndarray + Cropped image data. + crop_msk : ndarray + Cropped mask data. + """ img_shape = np.array(img.shape)[1:] assert len(img_shape)==len(crop_shape),"[Error] Not the same dimensions! Image shape {}, Crop shape {}".format(img_shape, crop_shape) @@ -135,6 +196,20 @@ def random_crop(img, msk, crop_shape): def centered_pad(img, final_size, msk=None): """ centered pad an img and msk to fit the final_size + + Parameters + ---------- + img : ndarray + Image data. + final_size : array_like + Final size after padding. + msk : ndarray, optional + Mask data. + + Returns + ------- + tuple or ndarray + Padded image and mask, or only the image if mask is None. """ final_size = np.array(final_size) img_shape = np.array(img.shape[1:]) @@ -156,6 +231,31 @@ def centered_pad(img, final_size, msk=None): def random_crop_pad(img, msk, final_size, fg_rate=0.33, fg_margin=np.zeros(3), fg=None, use_softmax=True): """ random crop and pad if needed. + + Parameters + ---------- + img : ndarray + Image data. + msk : ndarray + Mask data. + final_size : array_like + Final size after cropping and padding. + fg_rate : float, optional + Probability of focusing the crop on the foreground. + fg_margin : array_like, optional + Margin around the foreground location. + fg : dict, optional + Foreground information. + use_softmax : bool, optional + If True, assumes softmax activation; otherwise sigmoid is used. + + Returns + ------- + img : ndarray + Cropped and padded image data. + msk : ndarray + Cropped and padded mask data. + """ if type(img)==list: # then batch mode imgs, msks = [], [] @@ -179,6 +279,20 @@ def random_crop_pad(img, msk, final_size, fg_rate=0.33, fg_margin=np.zeros(3), f return img, msk class RandomCropAndPadTransform(AbstractTransform): + """ + BatchGenerator transform for random cropping and padding. + + Parameters + ---------- + crop_size : array_like + Size of the crop. + fg_rate : float, optional + Probability of focusing the crop on the foreground. + data_key : str, optional + Key for the data in the data dictionary. + label_key : str, optional + Key for the label in the data dictionary. + """ def __init__(self, crop_size, fg_rate=0.33, data_key="data", label_key="seg"): self.data_key = data_key self.label_key = label_key @@ -601,6 +715,31 @@ def get_validation_transforms(patch_size: Union[np.ndarray, Tuple[int]], class BatchGenDataLoader(SlimDataLoaderBase): """ Similar as torchio.SubjectsDataset but can be use with an unlimited amount of steps. + + Parameters + ---------- + img_dir : str + Directory containing the images. + msk_dir : str + Directory containing the masks. + batch_size : int + Size of the batches. + nbof_steps : int + Number of steps per epoch. + fg_dir : str, optional + Directory containing foreground information. + folds_csv : str, optional + CSV file containing fold information for dataset splitting. + fold : int, optional + Current fold number for training/validation splitting. + val_split : float, optional + Proportion of data to be used for validation. + train : bool, optional + If True, use the dataset for training; otherwise, use it for validation. + load_data : bool, optional + If True, loads the entire dataset into computer memory. + num_threads_in_mt : int, optional + Number of threads in multi-threaded augmentation. 
""" def __init__( @@ -810,6 +949,39 @@ def configure_rotation_dummyDA_mirroring_and_inital_patch_size(patch_size): return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes class MTBatchGenDataLoader(MultiThreadedAugmenter): + """ + Multi-threaded data loader for efficient data augmentation and loading. + + Parameters + ---------- + img_dir : str + Directory containing the images. + msk_dir : str + Directory containing the masks. + patch_size : array_like + The size of the patches to be extracted. + batch_size : int + Size of the batches. + nbof_steps : int + Number of steps per epoch. + fg_dir : str, optional + Directory containing foreground information. + folds_csv : str, optional + CSV file containing fold information for dataset splitting. + fold : int, optional + Current fold number for training/validation splitting. + val_split : float, optional + Proportion of data to be used for validation. + train : bool, optional + If True, use the dataset for training; otherwise, use it for validation. + load_data : bool, optional + If True, loads the entire dataset into computer memory. + fg_rate : float, optional + Foreground rate for cropping. + num_threads_in_mt : int, optional + Number of threads in multi-threaded augmentation. + """ + def __init__( self, img_dir, diff --git a/src/biom3d/datasets/semseg_patch_fast.py b/src/biom3d/datasets/semseg_patch_fast.py index ba3f5c2..7e04203 100644 --- a/src/biom3d/datasets/semseg_patch_fast.py +++ b/src/biom3d/datasets/semseg_patch_fast.py @@ -20,6 +20,26 @@ def centered_crop(img, msk, center, crop_shape, margin=np.zeros(3)): """Do a crop, forcing the location voxel to be located in the center of the crop. + + Parameters + ---------- + img : + Image data. + msk : + Mask data. + center : + Center voxel location for cropping. + crop_shape : + Shape of the crop. + margin : + Margin around the center location. + + Returns + ------- + crop_img : ndarray + Cropped image data. + crop_msk : ndarray + Cropped mask data. """ img_shape = np.array(img.shape)[1:] center = np.array(center) @@ -49,6 +69,27 @@ def centered_crop(img, msk, center, crop_shape, margin=np.zeros(3)): def located_crop(img, msk, location, crop_shape, margin=np.zeros(3)): """Do a crop, forcing the location voxel to be located in the crop. + + Parameters + ---------- + img : ndarray + Image data. + msk : ndarray + Mask data. + location : + Specific voxel location to include in the crop. + crop_shape : + Shape of the crop. + margin : + Margin around the location. + + Returns + ------- + crop_img : ndarray + Cropped image data. + crop_msk : ndarray + Cropped mask data. + """ img_shape = np.array(img.shape)[1:] location = np.array(location) @@ -69,6 +110,29 @@ def located_crop(img, msk, location, crop_shape, margin=np.zeros(3)): def foreground_crop(img, msk, final_size, fg_margin, fg=None, use_softmax=True): """Do a foreground crop. + + Parameters + ---------- + img : ndarray + Image data. + msk : ndarray + Mask data. + final_size : + Final size of the cropped image and mask. + fg_margin : + Margin around the foreground location. + fg : dict + Foreground information. + use_softmax : bool, optional + If True, assumes softmax activation. + + Returns + ------- + img : ndarray + Cropped image data, focused on the foreground region. + msk : ndarray + Cropped mask data, corresponding to the cropped image region. 
+ """ if fg is not None and len(list(fg.keys()))>0: locations = fg[random.choice(list(fg.keys()))] @@ -113,6 +177,25 @@ def centered_pad(img, final_size, msk=None): def random_crop(img, msk, crop_shape, force_in=True): """ randomly crop a portion of size prop of the original image size. + + Parameters + ---------- + img : ndarray + Image data. + msk : ndarray + Mask data. + crop_shape : + Shape of the crop. + force_in : bool, optional + If True, ensures the crop is fully within the image boundaries. + + Returns + ------- + crop_img : ndarray + Cropped image data. + crop_msk : ndarray + Cropped mask data. + """ img_shape = np.array(img.shape)[1:] assert len(img_shape)==len(crop_shape),"[Error] Not the same dimensions! Image shape {}, Crop shape {}".format(img_shape, crop_shape) @@ -139,6 +222,31 @@ def random_crop(img, msk, crop_shape, force_in=True): def random_crop_pad(img, msk, final_size, fg_rate=0.33, fg_margin=np.zeros(3), fg=None, use_softmax=True): """ random crop and pad if needed. + + Parameters + ---------- + img : ndarray + Image data. + msk : ndarray + Mask data. + final_size : + Final size of the image and mask after cropping and padding. + fg_rate : float, optional + Probability of focusing the crop on the foreground. + fg_margin : + Margin around the foreground location. + fg : dict, optional + Foreground information. + use_softmax : bool, optional + If True, assumes softmax activation; + + Returns + ------- + img : ndarray + Cropped and padded image data. + msk : ndarray + Cropped and padded mask data. + """ if type(img)==list: # then batch mode imgs, msks = [], [] @@ -164,8 +272,28 @@ def random_crop_pad(img, msk, final_size, fg_rate=0.33, fg_margin=np.zeros(3), f def random_crop_resize(img, msk, crop_scale, final_size, fg_rate=0.33, fg_margin=np.zeros(3)): """ random crop and resize if needed. - Args: - crop_scale: >=1 + + Parameters + ---------- + img : ndarray + Image data. + msk : ndarray + Mask data. + crop_scale : >1 + Scale factor for the crop size. + final_size : + Final size of the image and mask after cropping and resizing. + fg_rate : float, optional + Probability of focusing the crop on the foreground. + fg_margin : + Margin around the foreground location. + + Returns + ------- + img : ndarray + Cropped and resized image data. + msk : ndarray + Cropped and resized mask data. """ final_size = np.array(final_size) @@ -201,6 +329,21 @@ def random_crop_resize(img, msk, crop_scale, final_size, fg_rate=0.33, fg_margin #--------------------------------------------------------------------------- class LabelToLong: + """ + Transform to convert label data to long (integer) type. + + Parameters + ---------- + label_name : str + Name of the label to be transformed. + + Returns + ------- + subject : dict + Dictionary with the label data converted to long (integer) type. + + """ + def __init__(self, label_name): self.label_name = label_name @@ -214,6 +357,42 @@ def __call__(self, subject): class SemSeg3DPatchFast(Dataset): """ with DataLoader + Dataset class for semantic segmentation with 3D patches. Supports data augmentation and efficient loading. + + Parameters + ---------- + img_dir : str + Directory containing the image files. + msk_dir : str + Directory containing the mask files. + batch_size : int + Batch size for dataset sampling. + patch_size : nd.array + Size of the patches to be used. + nbof_steps : int + Number of steps (batches) per epoch. + folds_csv : str, optional + CSV file containing fold information for dataset splitting. 
+ fold : int, optional + The current fold number for training/validation splitting. + val_split : float, optional + Proportion of data to be used for validation. + train : bool, optional + If True, use the dataset for training; otherwise, use it for validation. + use_aug : bool, optional + If True, apply data augmentation. + aug_patch_size : nd.array, optional + Patch size to use for augmented patches. + fg_dir : str, optional + Directory containing foreground information. + fg_rate : float, optional + Foreground rate, used to force foreground inclusion in patches. + crop_scale : float, optional + Scale factor for crop size during augmentation. + load_data : bool, optional + If True, load the entire dataset into memory. + use_softmax : bool, optional + If True, use softmax activation. """ def __init__( self, diff --git a/src/biom3d/datasets/semseg_torchio.py b/src/biom3d/datasets/semseg_torchio.py index f2d13a6..644ca83 100644 --- a/src/biom3d/datasets/semseg_torchio.py +++ b/src/biom3d/datasets/semseg_torchio.py @@ -52,6 +52,8 @@ def __init__( Used with the foreground rate. Name of the label image in the tio.Subject. use_softmax: boolean, default=True Used with the foreground rate to know if the background should be removed. + **kwargs : dict + Additional keyword arguments. """ super().__init__(**kwargs) patch_size_array = np.array(to_tuple(patch_shape, length=3)) @@ -173,6 +175,14 @@ def apply_transform(self, subject: Subject) -> Subject: # utilities to change variable type in label/mask class LabelToFloat: + """ + Transform to convert label data to float type. + + Parameters + ---------- + label_name : str + Name of the label to be transformed. + """ def __init__(self, label_name): self.label_name = label_name @@ -182,6 +192,14 @@ def __call__(self, subject): return subject class LabelToLong: + """ + Transform to convert label data to long (integer) type. + + Parameters + ---------- + label_name : str + Name of the label to be transformed. + """ def __init__(self, label_name): self.label_name = label_name @@ -191,6 +209,14 @@ def __call__(self, subject): return subject class LabelToBool: + """ + Transform to convert label data to boolean type. + + Parameters + ---------- + label_name : str + Name of the label to be transformed. + """ def __init__(self, label_name): self.label_name = label_name @@ -202,6 +228,19 @@ def __call__(self, subject): #--------------------------------------------------------------------------- def reader(x): + """ + Custom reader function for image data. + + Parameters + ---------- + x : str + Path to the image file. + + Returns + ------- + Tuple + Loaded image data and metadata (if any). + """ return adaptive_imread(str(x))[0], None #--------------------------------------------------------------------------- @@ -234,9 +273,38 @@ def __init__( """ Parameters ---------- - load_data : boolean, default=False - if True, loads the all dataset into computer memory (faster but more memory expensive). ONLY COMPATIBLE WITH .npy PREPROCESSED IMAGES + img_dir : str + Directory containing the image files. + msk_dir : str + Directory containing the mask files. + batch_size : int + Batch size for dataset sampling. + patch_size : nd.array + Size of the patches to be used. + nbof_steps : int + Number of steps (batches) per epoch. + fg_dir : str, optional + Directory containing foreground information. + folds_csv : str, optional + CSV file containing fold information for dataset splitting. + fold : int, default=0 + The current fold number for training/validation splitting. 
+ val_split : float, default=0.25 + Proportion of data to be used for validation. + train : bool, default=True + If True, use the dataset for training; otherwise, use it for validation. + use_aug : bool, default=True + If True, apply data augmentation. + aug_patch_size : nd.array + Patch size to use for augmented patches. + fg_rate : float, default=0.33 + Foreground rate, used to force foreground inclusion in patches. + load_data : bool, default=False + If True, loads the whole dataset into computer memory (faster but more memory expensive). ONLY COMPATIBLE WITH .npy PREPROCESSED IMAGES + use_softmax : bool, default=True + If True, use softmax activation; otherwise, sigmoid is used. """ + self.img_dir = img_dir self.msk_dir = msk_dir self.fg_dir = fg_dir diff --git a/src/biom3d/gui.py b/src/biom3d/gui.py index 222348e..b2e607c 100644 --- a/src/biom3d/gui.py +++ b/src/biom3d/gui.py @@ -42,6 +42,7 @@ from biom3d.utils import load_python_config from biom3d.train import train from biom3d.eval import eval + from biom3d.omero_downloader import download_object import torch except ImportError as e: @@ -591,7 +592,7 @@ class ConfigFrame(ttk.LabelFrame): """ Load or auto configure training parameters """ - def __init__(self, train_folder_selection=None, *arg, **kw): + def __init__(self, train_folder_selection=None, omero=None, omero_dataset=None, get_use_omero=False, *arg, **kw): super(ConfigFrame, self).__init__(*arg, **kw) # widgets definitions @@ -629,6 +630,10 @@ def __init__(self, train_folder_selection=None, *arg, **kw): self.img_outdir = train_folder_selection.img_outdir self.msk_outdir = train_folder_selection.msk_outdir self.config_dir = train_folder_selection.config_dir + + self.get_use_omero = get_use_omero + self.omero = omero + self.omero_dataset = omero_dataset self.auto_config_finished = ttk.Label(self, text="") # Number of epochs @@ -828,7 +833,7 @@ def auto_config(self): msk_dir_train = "data/{}/msk_out".format(selected_dataset) fg_dir_train = "data/{}/fg_out".format(selected_dataset) # error management - if len(auto_config_results)!=10: + if len(auto_config_results) not in (10,15): print("[Error] Auto-config error:", auto_config_results) popupmsg("[Error] Auto-config error: "+ str(auto_config_results)) while True: @@ -837,6 +842,7 @@ break print(line, end="") + # TODO: rewrite this part # get auto config results reversed_auto_config_results = auto_config_results[::-1] @@ -863,14 +869,30 @@ local_config_dir = local_config_dir.replace("\\", "\\\\") local_logs_dir = local_logs_dir.replace("\\", "\\\\") - config_path=auto_config_preprocess(img_dir=self.img_outdir.get(), - msk_dir=self.msk_outdir.get(), + + if self.get_use_omero(): + # download the raw-image and mask datasets from OMERO and preprocess from there + raw_dataset = "Dataset:"+self.omero_dataset.id_entry.get() + mask_dataset = "Dataset:"+self.omero_dataset.msk_id_entry.get() + print("Using OMERO!
") + datasets, img_dir = download_object(hostname=self.omero.hostname.get(), username=self.omero.username.get(), password=self.omero.password.get(), target_dir = "data/to_train/", obj=raw_dataset ) + datasets_mask, msk_dir = download_object(hostname=self.omero.hostname.get(), username=self.omero.username.get(), password=self.omero.password.get(), target_dir = "data/to_train/", obj=mask_dataset ) + img_dir = os.path.join(img_dir, datasets[0].name) + msk_dir = os.path.join(msk_dir, datasets_mask[0].name) + + else : + img_dir=self.img_outdir.get() + msk_dir=self.msk_outdir.get() + + config_path=auto_config_preprocess(img_dir=img_dir, + msk_dir=msk_dir, desc=self.builder_name_entry.get(), num_classes=self.num_classes.get(), remove_bg=False, use_tif=False, config_dir=local_config_dir, logs_dir=local_logs_dir, base_config=None, + is_2d=is_2d.get(), ) # Read the config file @@ -928,18 +950,34 @@ class TrainTab(ttk.Frame): def __init__(self, *arg, **kw): super(TrainTab, self).__init__(*arg, **kw) global new_config_path - global sent_dataset + global is_2d + + # Omero Preprocessing + self.omero_dataset = OmeroPreprocessing(self, text="Selection of Omero datasets", padding=[10,10,10,10]) + self.omero_connection = Connect2Omero(self, text="Connection to Omero server", padding=[10,10,10,10]) + self.use_omero_preprocessing_state = IntVar(value=0) + self.use_omero_preprocessing_button = ttk.Checkbutton(self, text="Download Datasets from Omero ? ", command=self.display_omero_preprocessing, variable=self.use_omero_preprocessing_state) + # Omero Connection ? + self.use_omero = self.use_omero_preprocessing_state.get() + + self.folder_selection = TrainFolderSelection(master=self, text="Preprocess", padding=[10,10,10,10]) - self.config_selection = ConfigFrame(train_folder_selection=self.folder_selection, master=self, text="Training configuration", padding=[10,10,10,10]) + self.config_selection = ConfigFrame(train_folder_selection=self.folder_selection, omero = self.omero_connection, omero_dataset = self.omero_dataset, get_use_omero = self.get_use_omero ,master=self, text="Training configuration", padding=[10,10,10,10]) self.train_button = ttk.Button(self, text="Start", style="train_button.TLabel", width =29, command=self.train) self.plot_button = ttk.Button(self, text="Plot Learning Curves", style="train_button.TLabel", width =29, command=self.get_logs_plot) self.fine_tune_button = ttk.Button(self, text="Fine-tune", style="train_button.TLabel", width=29, command=self.train) self.train_done = ttk.Label(self, text="") + # config folder self.dataset_preprocessed_state = IntVar(value=0) self.use_conf_button = ttk.Checkbutton(self, text="Dataset is already preprocessed ? ", command=self.display_conf_finetuning, variable=self.dataset_preprocessed_state) + # 2d image ? + self.is_2d = IntVar(value=0) + self.is_2d_check_button = ttk.Checkbutton(self, text="Process 2D images ? ", variable=self.is_2d) + + is_2d = self.is_2d #Fine tuning self.fine_tune_state = IntVar(value=0) self.use_tune_button = ttk.Checkbutton(self, text="Use Fine-Tuning ? 
", command=self.display_conf_finetuning ,variable=self.fine_tune_state) @@ -950,8 +988,10 @@ def __init__(self, *arg, **kw): # set default values of train folders with the ones used for preprocess tab if not REMOTE : self.use_conf_button.grid(column=0,row=0,sticky=(N,W,E), pady=3) + self.is_2d_check_button.grid(column=0,row=0,sticky=(N,E), pady=3) else : self.plot_button.grid(column=0, row=5, padx=15, ipady=4, pady= 2, sticky=(N)) - self.use_tune_button.grid(column=0,row=1,sticky=(N,W,E), ipady=5) + self.use_tune_button.grid(column=0,row=1,sticky=(N,W), ipady=5) + self.use_omero_preprocessing_button.grid(column=0,row=1,sticky=(N,E), ipady=2) self.folder_selection.grid(column=0,row=2,sticky=(N,W,E), pady=3) self.config_selection.grid(column=0,row=3,sticky=(N,W,E), pady=20) self.train_button.grid(column=0, row=4,padx=15, ipady=4, pady= 2, sticky=(N)) @@ -960,11 +1000,43 @@ def __init__(self, *arg, **kw): self.columnconfigure(0, weight=1) for i in range(5): self.rowconfigure(i, weight=1) + def get_use_omero(self): + # Return the current state of use_omero_preprocessing_state + return self.use_omero_preprocessing_state.get() + def display_omero_preprocessing(self): + # TODO : Deal with fine tuning + + if self.use_omero_preprocessing_state.get(): + # Remove folder selection img and msk directories + self.folder_selection.grid_remove() + self.folder_selection.label1.grid_remove() + self.folder_selection.img_outdir.grid_remove() + self.folder_selection.label2.grid_remove() + self.folder_selection.msk_outdir.grid_remove() + self.use_conf_button.grid_remove() + self.use_tune_button.grid_remove() + + + self.omero_connection.grid(column=0,row=1,sticky=(N,W,E), pady=3) + self.omero_dataset.grid(column=0,row=2,sticky=(S,W,E), pady=5) + + else : + self.folder_selection.grid(column=0,row=2,sticky=(N,W,E), pady=3) + self.omero_connection.grid_remove() + self.omero_dataset.grid_remove() + self.use_conf_button.grid(column=0,row=0,sticky=(N,W,E), pady=3) + # Folder selection + self.folder_selection.label1.grid(column=0,row=2, sticky=W, pady=7) + self.folder_selection.img_outdir.grid(column=0, row=3, sticky=(W,E)) + self.folder_selection.label2.grid(column=0,row=4, sticky=W, pady=7) + self.folder_selection.msk_outdir.grid(column=0,row=5, sticky=(W,E)) + self.use_tune_button.grid(column=0,row=1,sticky=(N,W), ipady=5) + def display_conf_finetuning(self): if not REMOTE: # All Cases possible - + self.use_omero_preprocessing_button.grid(column=0,row=1,sticky=(N,E), ipady=2) if self.dataset_preprocessed_state.get() and self.fine_tune_state.get(): self.display_one() elif self.dataset_preprocessed_state.get() and not self.fine_tune_state.get(): @@ -985,10 +1057,10 @@ def display_conf_finetuning(self): self.folder_selection.label3.grid_remove() # Add the buttons back - self.use_tune_button.grid(column=0,row=1,sticky=(N,W,E), pady=3) + self.use_tune_button.grid(column=0,row=1,sticky=(N,W), pady=3) self.train_button.grid(column=0, row=4,padx=15, ipady=4, pady= 2, sticky=(N)) self.train_done.grid(column=0, row=5, sticky=W) - + self.use_omero_preprocessing_button.grid(column=0,row=1,sticky=(N,E), ipady=2) # Folder selection self.folder_selection.label1.grid(column=0,row=2, sticky=W, pady=7) self.folder_selection.img_outdir.grid(column=0, row=3, sticky=(W,E)) @@ -1039,8 +1111,10 @@ def display_conf_finetuning(self): self.folder_selection.grid(column=0,row=2, pady= 2, sticky=(W,E)) self.config_selection.grid(column=0,row=3,sticky=(N,W,E), pady=20) self.train_button.grid(column=0, row=5,padx=15, ipady=4, pady= 2, sticky=(N)) 
+ self.use_omero_preprocessing_button.grid(column=0,row=1,sticky=(N,E), ipady=2) def display_one(self): + self.use_omero_preprocessing_button.grid(column=0,row=1,sticky=(N,E), ipady=2) # Folder selection self.folder_selection.grid(column=0,row=2,sticky=(N,W,E), pady=3) # Fine tuning @@ -1065,6 +1139,7 @@ def display_one(self): self.folder_selection.msk_outdir.grid_remove() # Remove train button self.train_button.grid_remove() + self.use_omero_preprocessing_button.grid_remove() def display_two(self): """ @@ -1072,6 +1147,7 @@ def display_two(self): """ if not REMOTE : # place the new ones + self.use_omero_preprocessing_button.grid(column=0,row=1,sticky=(N,E), ipady=2) self.folder_selection.label3.grid(column=0, row=1, sticky=W, pady=5) self.config_selection.load_config_button.grid(column=0, columnspan=4,row=4 ,ipady=4, pady=2,) self.folder_selection.config_dir.grid(column=0,row=6, sticky=(W,E)) @@ -1080,6 +1156,7 @@ def display_two(self): self.train_button.grid(column=0, row=4,padx=15, ipady=4, pady= 2, sticky=(N)) self.train_done.grid(column=0, row=5, sticky=W) # Remove everything else + self.use_omero_preprocessing_button.grid_remove() self.config_selection.builder_name_label.grid_remove() self.config_selection.builder_name_entry.grid_remove() self.config_selection.num_classes_label.grid_remove() @@ -1091,8 +1168,8 @@ def display_two(self): self.folder_selection.msk_outdir.grid_remove() self.FineTuning.grid_remove() self.fine_tune_button.grid_remove() - def display_three(self): + self.use_omero_preprocessing_button.grid(column=0,row=1,sticky=(N,E), ipady=2) self.folder_selection.grid(column=0, row=2, sticky=(N,W,E), pady=3) self.FineTuning.grid(column=0, row=3, sticky=(N,W,E), pady=3) self.config_selection.grid(column=0, row=4, sticky=(N,W,E), pady=3) @@ -1326,6 +1403,25 @@ def train_nohup(): else : # run the training train(config=new_config_path) + if self.use_omero_preprocessing_state.get() : + logs_path = "./logs" # Use relative path + if not os.path.exists(logs_path): + print(f"Directory '{logs_path}' does not exist.") + else: + directories = [d for d in os.listdir(logs_path) if os.path.isdir(os.path.join(logs_path, d))] + if not directories: + print("No directories found in the logs path.") + else: + directories.sort(key=lambda d: os.path.getmtime(os.path.join(logs_path, d)), reverse=True) + last_folder = directories[0] + image_folder = os.path.join(logs_path, last_folder, "image") + + biom3d.omero_uploader.run(username=self.omero_connection.username_entry.get(), password=self.omero_connection.password_entry.get(), hostname=self.omero_connection.hostname_entry.get() , + project=int(self.omero_dataset.id_entry.get()), + path=image_folder, + attachment=last_folder, + is_pred=False, + ) popupmsg(" Training done ! 
") class FineTuning(ttk.LabelFrame): @@ -1382,7 +1478,7 @@ def __init__(self, *arg, **kw): self.rowconfigure(i, weight=1) else: - self.data_dir = FileDialog(self, mode='folder', textEntry=os.path.join('data', 'to_pred')) + self.data_dir = FileDialog(self, mode='folder', textEntry=os.path.join('', '')) self.data_dir.grid(column=0, row=1, sticky=(W,E)) self.columnconfigure(0, weight=1) @@ -1439,21 +1535,52 @@ def __init__(self, *arg, **kw): self.columnconfigure(1, weight=5) for i in range(3): self.rowconfigure(i, weight=1) + +class OmeroPreprocessing(ttk.LabelFrame): + """ + Choose an input Dataset from OMERO for Predictions + """ + def __init__(self, *arg, **kw): + super(OmeroPreprocessing, self).__init__(*arg, **kw) + + self.label_id = ttk.Label(self, text="Raw Images Dataset ID:") + self.id = StringVar(value="") + self.id_entry = ttk.Entry(self, textvariable=self.id) + + self.msk_label_id = ttk.Label(self, text="Masks Dataset ID:") + self.msk_id = StringVar(value="") + self.msk_id_entry = ttk.Entry(self, textvariable=self.msk_id) + + self.label_id.grid(column=0, row=0, sticky=(W,E)) + self.id_entry.grid(column=1,row=0,sticky=(W,E)) + + self.msk_label_id.grid(column=0, row=1, sticky=(W,E)) + self.msk_id_entry.grid(column=1,row=1,sticky=(W,E)) + self.columnconfigure(0, weight=1) + self.columnconfigure(1, weight=5) + self.rowconfigure(0, weight=1) class OmeroDataset(ttk.LabelFrame): """ - Choose an input Dataset from OMERO for Predictions + Choose an input Dataset from OMERO for Training """ def __init__(self, *arg, **kw): super(OmeroDataset, self).__init__(*arg, **kw) self.label_id = ttk.Label(self, text="Input Dataset ID:") - self.id = StringVar(value="19699") + self.id = StringVar(value="") self.id_entry = ttk.Entry(self, textvariable=self.id) + + self.project_label_id = ttk.Label(self, text="Output Project ID:") + self.project_id = StringVar(value="") + self.project_id_entry = ttk.Entry(self, textvariable=self.project_id) + self.label_id.grid(column=0, row=0, sticky=(W,E)) self.id_entry.grid(column=1,row=0,sticky=(W,E)) + self.project_label_id.grid(column=0, row=1, sticky=(W,E)) + self.project_id_entry.grid(column=1,row=1,sticky=(W,E)) self.columnconfigure(0, weight=1) self.columnconfigure(1, weight=5) self.rowconfigure(0, weight=1) @@ -1489,7 +1616,7 @@ def __init__(self, *arg, **kw): else: ## build folder self.label1 = ttk.Label(self, text="Select the folder containing the build:", anchor="sw", background='white') - self.logs_dir = FileDialog(self, mode='folder', textEntry='logs/') + self.logs_dir = FileDialog(self, mode='folder', textEntry='') self.logs_dir.grid(column=0, row=0, sticky=(W,E)) self.columnconfigure(0, weight=1) @@ -1557,7 +1684,7 @@ def __init__(self,*arg, **kw): self.prediction_folder= FileDialog(self, mode='folder', textEntry="data/pred") self.upload_project_label= ttk.Label(self, text="Project ID : ") #change to project id and send dataset name - self.project_id= StringVar(value="12906") + self.project_id= StringVar(value="") self.upload_project_entry= ttk.Entry(self,textvariable=self.project_id) self.upload_dataset_label= ttk.Label(self, text="Select a name for your dataset : ") #change to project id and send dataset name @@ -1623,7 +1750,7 @@ def send_to_omero(self): # get the last folder modified/created _,stdout,stderr=REMOTE.exec_command("ls -td {}/data/pred/{}/*/ | head -1".format(MAIN_DIR, self.dataset_selected_omero() )) last_folder = stdout.readline().replace('\n','') - _,stdout,stderr=REMOTE.exec_command("source {}/bin/activate; cd {}; python -m 
biom3d.omero_uploader --username {} --password {} --hostname {} --project {} --path '{}' --dataset_name {}".format(VENV,MAIN_DIR, self.username_entry.get(), self.password_entry.get(), self.hostname_entry.get(), self.upload_project_entry.get(), last_folder,self.dataset_name_entry.get() )) + _,stdout,stderr=REMOTE.exec_command("source {}/bin/activate; cd {}; python -m biom3d.omero_uploader --username {} --password {} --hostname {} --project {} --path '{}' --dataset_name {} --is_pred".format(VENV,MAIN_DIR, self.username_entry.get(), self.password_entry.get(), self.hostname_entry.get(), self.upload_project_entry.get(), last_folder,self.dataset_name_entry.get() )) while True: line = stdout.readline() @@ -1802,6 +1929,8 @@ def __init__(self, *arg, **kw): self.rowconfigure(i, weight=1) def predict(self): + #TODO test with None + attachment_file = self.model_selection.logs_dir.get() # if use Omero then use Omero prediction if REMOTE : # To Filter objects in Prediction @@ -1860,19 +1989,22 @@ def predict(self): if self.use_omero_state.get(): + obj="Dataset"+":"+self.omero_dataset.id.get() if REMOTE: # TODO: below, still OS dependant # Run OMERO prediction - _, stdout, stderr = REMOTE.exec_command("source {}/bin/activate; cd {}; python -m biom3d.omero_pred --obj {} --log {} --username {} --password {} --hostname {} ".format(VENV, + _, stdout, stderr = REMOTE.exec_command("source {}/bin/activate; cd {}; python -m biom3d.omero_pred --obj {} --log {} --username {} --password {} --hostname {} --upload_id {} --attachment {} ".format(VENV, MAIN_DIR, obj, - MAIN_DIR+'/logs/'+self.model_selection.logs_dir.get(), + MAIN_DIR+'logs/'+self.model_selection.logs_dir.get(), self.omero_connection.username.get(), self.omero_connection.password.get(), - self.omero_connection.hostname.get() + self.omero_connection.hostname.get(), + self.omero_dataset.project_id.get(), + MAIN_DIR+'logs/' +attachment_file, )) while True: line = stdout.readline() @@ -1894,6 +2026,7 @@ def predict(self): if not os.path.isdir(target): os.makedirs(target, exist_ok=True) print("Downloading Omero dataset into", target) + # run OMERO prediction p=biom3d.omero_pred.run( obj=obj, @@ -1902,15 +2035,10 @@ def predict(self): dir_out=self.output_dir.data_dir.get(), user=self.omero_connection.username.get(), pwd=self.omero_connection.password.get(), - host=self.omero_connection.hostname.get() + host=self.omero_connection.hostname.get(), + upload_id=int(self.omero_dataset.project_id.get()), + attachment=attachment_file, ) - if self.send_to_omero_state.get(): - biom3d.omero_uploader.run(username=self.send_to_omero_connection.username.get(), - password=self.send_to_omero_connection.password.get(), - hostname=self.send_to_omero_connection.hostname.get(), - project=int(self.send_to_omero_connection.upload_project_entry.get()), - dataset_name=self.send_to_omero_connection.dataset_name_entry.get(), - path=p) else: # if not use Omero if REMOTE: @@ -1954,6 +2082,8 @@ def predict(self): hostname=self.send_to_omero_connection.hostname.get(), project=int(self.send_to_omero_connection.upload_project_entry.get()), dataset_name=self.send_to_omero_connection.dataset_name_entry.get(), + attachment=attachment_file, + is_pred=True, path=p) popupmsg("Prediction done !") @@ -1971,7 +2101,7 @@ def display_omero(self): self.omero_connection.grid(column=0,row=1,sticky=(W,E), pady=6) self.omero_dataset.grid(column=0,row=2,sticky=(W,E), pady=6) - + self.send_to_omero.grid_remove() else: # hide omero self.omero_connection.grid_remove() @@ -1979,7 +2109,7 @@ def 
display_omero(self): # reset the input dir self.input_dir.grid(column=0,row=1,sticky=(W,E)) - + self.send_to_omero.grid(column=0,row=4,sticky=(W,E), pady=6) def display_send_to_omero(self): """ For displaying and hiding OMERO tab @@ -2085,7 +2215,7 @@ def __init__(self, *arg, **kw): self.password_entry = ttk.Entry(self, textvariable=self.password, show='*') self.main_dir_label = ttk.Label(self, text='Folder of Biom3d repository on remote server:') - self.main_dir = StringVar(value="/home/biome/biom3d") + self.main_dir = StringVar(value="/home/biome/") self.main_dir_entry = ttk.Entry(self, textvariable=self.main_dir) self.venv_label = ttk.Label(self, text='(Optional) Name of the virtual environment on remote server:') diff --git a/src/biom3d/models/encoder_efficientnet3d.py b/src/biom3d/models/encoder_efficientnet3d.py new file mode 100644 index 0000000..fcf7f30 --- /dev/null +++ b/src/biom3d/models/encoder_efficientnet3d.py @@ -0,0 +1,486 @@ +#--------------------------------------------------------------------------- +# 3D efficient net adapted from: +# https://github.com/shijianjian/EfficientNet-PyTorch-3D +# +# usage: +# model = EfficientNet3D.from_name("efficientnet-b1", override_params={'include_top': False}, in_channels=1) +# model.cuda() + +# lists of pyramid layers: +# { +# 0: ['_conv_stem'], # 100 +# 1: ['_blocks', '1', '_bn0'], # 50 +# 2: ['_blocks', '3', '_bn0'], # 25 +# 3: ['_blocks', '5', '_bn0'], # 12 +# 4: ['_blocks', '11', '_bn0'], # 6 +# 5: ['_bn1'] # 3 +# } +#--------------------------------------------------------------------------- + +""" +This file contains helper functions for building the model and for loading model parameters. +These helper functions are built to mirror those in the official TensorFlow implementation. +""" + +import re +import math +import collections +import torch +from torch import nn +from torch.nn import functional as F +import numpy as np + +######################################################################## +############### HELPERS FUNCTIONS FOR MODEL ARCHITECTURE ############### +######################################################################## + + +# Parameters for the entire model (stem, all blocks, and head) +GlobalParams = collections.namedtuple('GlobalParams', [ + 'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate', + 'num_classes', 'width_coefficient', 'depth_coefficient', + 'depth_divisor', 'min_depth', 'drop_connect_rate', 'image_size', 'include_top']) + +# Parameters for an individual model block +BlockArgs = collections.namedtuple('BlockArgs', [ + 'kernel_size', 'num_repeat', 'input_filters', 'output_filters', + 'expand_ratio', 'id_skip', 'stride', 'se_ratio']) + +# Change namedtuple defaults +GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields) +BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields) + + +class SwishImplementation(torch.autograd.Function): + @staticmethod + def forward(ctx, i): + result = i * torch.sigmoid(i) + ctx.save_for_backward(i) + return result + + @staticmethod + def backward(ctx, grad_output): + i = ctx.saved_variables[0] + sigmoid_i = torch.sigmoid(i) + return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) + + +class MemoryEfficientSwish(nn.Module): + def forward(self, x): + return SwishImplementation.apply(x) + +class Swish(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +def round_filters(filters, global_params): + """ Calculate and round number of filters based on depth multiplier. 
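+    For example, with width_coefficient=1.4 and depth_divisor=8 (efficientnet-b4),
+    round_filters(32, global_params) scales 32 to 44.8 and rounds it up to 48.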
""" + multiplier = global_params.width_coefficient + if not multiplier: + return filters + divisor = global_params.depth_divisor + min_depth = global_params.min_depth + filters *= multiplier + min_depth = min_depth or divisor + new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor) + if new_filters < 0.9 * filters: # prevent rounding by more than 10% + new_filters += divisor + return int(new_filters) + + +def round_repeats(repeats, global_params): + """ Round number of filters based on depth multiplier. """ + multiplier = global_params.depth_coefficient + if not multiplier: + return repeats + return int(math.ceil(multiplier * repeats)) + + +def drop_connect(inputs, p, training): + """ Drop connect. """ + if not training: return inputs + batch_size = inputs.shape[0] + keep_prob = 1 - p + random_tensor = keep_prob + random_tensor += torch.rand([batch_size, 1, 1, 1, 1], dtype=inputs.dtype, device=inputs.device) + binary_tensor = torch.floor(random_tensor) + output = inputs / keep_prob * binary_tensor + return output + +######################################################################## +############## HELPERS FUNCTIONS FOR LOADING MODEL PARAMS ############## +######################################################################## + + +def efficientnet_params(model_name): + """ Map EfficientNet model name to parameter coefficients. """ + params_dict = { + # Coefficients: width,depth,res,dropout + 'efficientnet-b0': (1.0, 1.0, 224, 0.2), + 'efficientnet-b1': (1.0, 1.1, 240, 0.2), + 'efficientnet-b2': (1.1, 1.2, 260, 0.3), + 'efficientnet-b3': (1.2, 1.4, 300, 0.3), + 'efficientnet-b4': (1.4, 1.8, 380, 0.4), + 'efficientnet-b5': (1.6, 2.2, 456, 0.4), + 'efficientnet-b6': (1.8, 2.6, 528, 0.5), + 'efficientnet-b7': (2.0, 3.1, 600, 0.5), + 'efficientnet-b8': (2.2, 3.6, 672, 0.5), + 'efficientnet-l2': (4.3, 5.3, 800, 0.5), + } + return params_dict[model_name] + + +class BlockDecoder(object): + """ Block Decoder for readability, straight from the official TensorFlow repository """ + + @staticmethod + def _decode_block_string(block_string): + """ Gets a block through a string notation of arguments. """ + assert isinstance(block_string, str) + + ops = block_string.split('_') + options = {} + for op in ops: + splits = re.split(r'(\d.*)', op) + if len(splits) >= 2: + key, value = splits[:2] + options[key] = value + + # Check stride + assert (('s' in options and len(options['s']) == 1) or + (len(options['s']) == 3 and options['s'][0] == options['s'][1] == options['s'][2])) + + return BlockArgs( + kernel_size=int(options['k']), + num_repeat=int(options['r']), + input_filters=int(options['i']), + output_filters=int(options['o']), + expand_ratio=int(options['e']), + id_skip=('noskip' not in block_string), + se_ratio=float(options['se']) if 'se' in options else None, + stride=[int(options['s'][0])]) + + @staticmethod + def _encode_block_string(block): + """Encodes a block to a string.""" + args = [ + 'r%d' % block.num_repeat, + 'k%d' % block.kernel_size, + 's%d%d%d' % (block.strides[0], block.strides[1], block.strides[2]), + 'e%s' % block.expand_ratio, + 'i%d' % block.input_filters, + 'o%d' % block.output_filters + ] + if 0 < block.se_ratio <= 1: + args.append('se%s' % block.se_ratio) + if block.id_skip is False: + args.append('noskip') + return '_'.join(args) + + @staticmethod + def decode(string_list): + """ + Decodes a list of string notations to specify blocks inside the network. 
+
+        :param string_list: a list of strings, each string is a notation of block
+        :return: a list of BlockArgs namedtuples of block args
+        """
+        assert isinstance(string_list, list)
+        blocks_args = []
+        for block_string in string_list:
+            blocks_args.append(BlockDecoder._decode_block_string(block_string))
+        return blocks_args
+
+    @staticmethod
+    def encode(blocks_args):
+        """
+        Encodes a list of BlockArgs to a list of strings.
+
+        :param blocks_args: a list of BlockArgs namedtuples of block args
+        :return: a list of strings, each string is a notation of block
+        """
+        block_strings = []
+        for block in blocks_args:
+            block_strings.append(BlockDecoder._encode_block_string(block))
+        return block_strings
+
+
+def efficientnet3d(width_coefficient=None, depth_coefficient=None, dropout_rate=0.2,
+                   drop_connect_rate=0.2, image_size=None, num_classes=1000, include_top=True):
+    """ Creates an efficientnet model. """
+
+    blocks_args = [
+        'r1_k3_s222_e1_i32_o16_se0.25', 'r2_k3_s222_e6_i16_o24_se0.25',
+        'r2_k5_s222_e6_i24_o40_se0.25', 'r3_k3_s222_e6_i40_o80_se0.25',
+        'r3_k5_s111_e6_i80_o112_se0.25', 'r4_k5_s222_e6_i112_o192_se0.25',
+        'r1_k3_s111_e6_i192_o320_se0.25',
+    ]
+    blocks_args = BlockDecoder.decode(blocks_args)
+
+    global_params = GlobalParams(
+        batch_norm_momentum=0.99,
+        batch_norm_epsilon=1e-3,
+        dropout_rate=dropout_rate,
+        drop_connect_rate=drop_connect_rate,
+        # data_format='channels_last', # removed, this is always true in PyTorch
+        num_classes=num_classes,
+        width_coefficient=width_coefficient,
+        depth_coefficient=depth_coefficient,
+        depth_divisor=8,
+        min_depth=None,
+        image_size=image_size,
+        include_top=include_top,
+    )
+
+    return blocks_args, global_params
+
+
+def get_model_params(model_name, override_params):
+    """ Get the block args and global params for a given model """
+    if model_name.startswith('efficientnet'):
+        w, d, s, p = efficientnet_params(model_name)
+        # note: all models have drop connect rate = 0.2
+        blocks_args, global_params = efficientnet3d(
+            width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s)
+    else:
+        raise NotImplementedError('model name is not pre-defined: %s' % model_name)
+    if override_params:
+        # ValueError will be raised here if override_params has fields not included in global_params.
+        global_params = global_params._replace(**override_params)
+    return blocks_args, global_params
+
+class MBConvBlock3D(nn.Module):
+    """
+    Mobile Inverted Residual Bottleneck Block
+
+    Args:
+        block_args (namedtuple): BlockArgs, see above
+        global_params (namedtuple): GlobalParams, see above
+
+    Attributes:
+        has_se (bool): Whether the block contains a Squeeze and Excitation layer.
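+        id_skip (bool): Whether a residual skip connection is used; it is only
+            applied when the stride is 1 and input_filters equals output_filters.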
+ """ + + def __init__(self, block_args, global_params): + # print("block_arg", block_args) + # print("global_params", global_params) + super().__init__() + self._block_args = block_args + self._bn_mom = 1 - global_params.batch_norm_momentum + self._bn_eps = global_params.batch_norm_epsilon + self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1) + self.id_skip = block_args.id_skip # skip connection and drop connect + + # Get static or dynamic convolution depending on image size + # Conv3d = get_same_padding_conv3d(image_size=global_params.image_size) + Conv3d = nn.Conv3d + + # Expansion phase + inp = self._block_args.input_filters # number of input channels + oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels + if self._block_args.expand_ratio != 1: + self._expand_conv = Conv3d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False) + # self._bn0 = nn.BatchNorm3d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) + self._bn0 = nn.InstanceNorm3d(num_features=oup) + + # Depthwise convolution phase + k = self._block_args.kernel_size + s = self._block_args.stride + self._depthwise_conv = Conv3d( + in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise + kernel_size=k, stride=s, bias=False, padding=np.array(k)//2) + # self._bn1 = nn.BatchNorm3d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) + self._bn1 = nn.InstanceNorm3d(num_features=oup) + + # Squeeze and Excitation layer, if desired + if self.has_se: + num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio)) + self._se_reduce = Conv3d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1) + self._se_expand = Conv3d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1) + + # Output phase + final_oup = self._block_args.output_filters + self._project_conv = Conv3d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False) + # self._bn2 = nn.BatchNorm3d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps) + self._bn2 = nn.InstanceNorm3d(num_features=final_oup) + self._swish = MemoryEfficientSwish() + + def forward(self, inputs, drop_connect_rate=None): + """ + :param inputs: input tensor + :param drop_connect_rate: drop connect rate (float, between 0 and 1) + :return: output of block + """ + + # Expansion and Depthwise Convolution + x = inputs + if self._block_args.expand_ratio != 1: + x = self._swish(self._bn0(self._expand_conv(inputs))) + x = self._swish(self._bn1(self._depthwise_conv(x))) + + # Squeeze and Excitation + if self.has_se: + x_squeezed = F.adaptive_avg_pool3d(x, 1) + x_squeezed = self._se_expand(self._swish(self._se_reduce(x_squeezed))) + x = torch.sigmoid(x_squeezed) * x + + x = self._bn2(self._project_conv(x)) + + # Skip connection and drop connect + input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters + if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters: + if drop_connect_rate: + x = drop_connect(x, p=drop_connect_rate, training=self.training) + x = x + inputs # skip connection + return x + + def set_swish(self, memory_efficient=True): + """Sets swish function as memory efficient (for training) or standard (for export)""" + self._swish = MemoryEfficientSwish() if memory_efficient else Swish() + + +class EfficientNet3D(nn.Module): + """ + An EfficientNet model. 
Most easily loaded with the .from_name or .from_pretrained methods + + Args: + blocks_args (list): A list of BlockArgs to construct blocks + global_params (namedtuple): A set of GlobalParams shared between blocks + + Example: + model = EfficientNet3D.from_pretrained('efficientnet-b0') + + """ + + def __init__(self, blocks_args=None, global_params=None, in_channels=3, num_pools=[5,5,5], first_stride=[1,1,1]): + super().__init__() + assert isinstance(blocks_args, list), 'blocks_args should be a list' + assert len(blocks_args) > 0, 'block args must be greater than 0' + self._global_params = global_params + self._blocks_args = blocks_args + + # Get static or dynamic convolution depending on image size + # Conv3d = get_same_padding_conv3d(image_size=global_params.image_size) + Conv3d = nn.Conv3d + + # Batch norm parameters + bn_mom = 1 - self._global_params.batch_norm_momentum + bn_eps = self._global_params.batch_norm_epsilon + + # Stem + out_channels = round_filters(32, self._global_params) # number of output channels + self._conv_stem = Conv3d(in_channels, out_channels, kernel_size=3, stride=first_stride, bias=False, padding=1) + # self._bn0 = nn.BatchNorm3d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) + self._bn0 = nn.InstanceNorm3d(num_features=out_channels) + + # Set adaptive number of pools + # for example: convert [3,5,5] into [[1 1 1],[1 2 2],[2 2 2],[2 2 2],[2 2 2],[1 2 2]] + max_pool = max(num_pools) + strides = [] + for i in range(len(num_pools)): + st = np.ones(max_pool) + num_zeros = max_pool-num_pools[i] + for j in range(num_zeros): + st[j]=0 + st=np.roll(st,-num_zeros//2) + strides += [st] + strides = np.array(strides).astype(int).T+1 + # kernels = (strides*3//2).tolist() + strides = strides.tolist() + + # Build blocks + self._blocks = nn.ModuleList([]) + crt_stride = 0 + for block_args in self._blocks_args: + + # Update block input and output filters based on depth multiplier. + block_args = block_args._replace( + input_filters=round_filters(block_args.input_filters, self._global_params), + output_filters=round_filters(block_args.output_filters, self._global_params), + num_repeat=round_repeats(block_args.num_repeat, self._global_params), + ) + + if np.greater(block_args.stride,1): + # block_args = block_args._replace(stride=strides[crt_stride],kernel_size=kernels[crt_stride]) + block_args = block_args._replace(stride=strides[crt_stride]) + crt_stride += 1 + + # The first block needs to take care of stride and filter size increase. 
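+            # Only this first copy applies the (possibly anisotropic) stride taken
+            # from `strides`; the num_repeat-1 copies appended below run with
+            # stride 1 and matching filters, so each stage downsamples at most once.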
+ self._blocks.append(MBConvBlock3D(block_args, self._global_params)) + if block_args.num_repeat > 1: + block_args = block_args._replace(input_filters=block_args.output_filters, stride=1) + for _ in range(block_args.num_repeat - 1): + self._blocks.append(MBConvBlock3D(block_args, self._global_params)) + + # Head + in_channels = block_args.output_filters # output of final block + out_channels = round_filters(1280, self._global_params) + self._conv_head = Conv3d(in_channels, out_channels, kernel_size=1, bias=False) + self._bn1 = nn.BatchNorm3d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) + + # Final linear layer + self._avg_pooling = nn.AdaptiveAvgPool3d(1) + self._dropout = nn.Dropout(self._global_params.dropout_rate) + self._fc = nn.Linear(out_channels, self._global_params.num_classes) + self._swish = MemoryEfficientSwish() + # self.set_swish(memory_efficient=False) + + def set_swish(self, memory_efficient=True): + """Sets swish function as memory efficient (for training) or standard (for export)""" + self._swish = MemoryEfficientSwish() if memory_efficient else Swish() + for block in self._blocks: + block.set_swish(memory_efficient) + + + def extract_features(self, inputs): + """ Returns output of the final convolution layer """ + + # Stem + x = self._swish(self._bn0(self._conv_stem(inputs))) + + # Blocks + for idx, block in enumerate(self._blocks): + drop_connect_rate = self._global_params.drop_connect_rate + if drop_connect_rate: + drop_connect_rate *= float(idx) / len(self._blocks) + x = block(x, drop_connect_rate=drop_connect_rate) + + # Head + x = self._swish(self._bn1(self._conv_head(x))) + + return x + + def forward(self, inputs): + """ Calls extract_features to extract features, applies final linear layer, and returns logits. """ + bs = inputs.size(0) + # Convolution layers + x = self.extract_features(inputs) + + if self._global_params.include_top: + # Pooling and final linear layer + x = self._avg_pooling(x) + x = x.view(bs, -1) + x = self._dropout(x) + x = self._fc(x) + return x + + @classmethod + def from_name(cls, model_name, override_params=None, in_channels=3, **kwargs): + cls._check_model_name_is_valid(model_name) + blocks_args, global_params = get_model_params(model_name, override_params) + return cls(blocks_args, global_params, in_channels, **kwargs) + + @classmethod + def get_image_size(cls, model_name): + cls._check_model_name_is_valid(model_name) + _, _, res, _ = efficientnet_params(model_name) + return res + + @classmethod + def _check_model_name_is_valid(cls, model_name): + """ Validates model name. 
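+        Accepts 'efficientnet-b0' through 'efficientnet-b8'; note that
+        'efficientnet-l2' has coefficients in efficientnet_params but is rejected here.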
""" + valid_models = ['efficientnet-b'+str(i) for i in range(9)] + if model_name not in valid_models: + raise ValueError('model_name should be one of: ' + ', '.join(valid_models)) + diff --git a/src/biom3d/models/unet3d_eff.py b/src/biom3d/models/unet3d_eff.py new file mode 100644 index 0000000..e7a0467 --- /dev/null +++ b/src/biom3d/models/unet3d_eff.py @@ -0,0 +1,184 @@ +#--------------------------------------------------------------------------- +# 3D efficient net stolen from: +# https://github.com/shijianjian/EfficientNet-PyTorch-3D +# +# usage: +# model = EfficientNet3D.from_name("efficientnet-b1", override_params={'include_top': False}, in_channels=1) +# model.cuda() +#--------------------------------------------------------------------------- + +from biom3d.models.encoder_vgg import EncoderBlock, SmallEncoderBlock +from biom3d.models.decoder_vgg_deep import VGGDecoder +from biom3d.models.encoder_efficientnet3d import EfficientNet3D, efficientnet3d + +import torch +from torch import nn + +def get_layer(model, layer_names): + """ + get a layer from a model from a list of its module and submodules + e.g.: l = ['_blocks','0','_depthwise_conv'] + """ + for e in layer_names: + model = model._modules[e] + return model + +def get_pyramid(model, pyramid): + """ + return a list of layers from the model described by the dictionary called 'pyramid'. + e.g.: + pyramid = { + 0: ['_conv_stem'], # 100 + 1: ['_blocks', '1', '_bn0'], # 50 + 2: ['_blocks', '3', '_bn0'], # 25 + 3: ['_blocks', '5', '_bn0'], # 12 + 4: ['_blocks', '11', '_bn0'], # 6 + 5: ['_bn1'] # 3 + } + """ + layers = [] + for v in pyramid.values(): + layers += [get_layer(model, v)] + return layers + +def get_outfmaps(layer): + """ + return the depth of output feature map. + """ + if 'num_features' in layer.__dict__.keys(): + return layer.num_features + elif 'in_channels' in layer.__dict__.keys(): + return layer.in_channels + else: + print("[Error] layer is not standard, cannot extract output feature maps.") + return 0 + +#--------------------------------------------------------------------------- +# 3D UNet with the previous encoder and decoder + +class EffUNet(nn.Module): + def __init__( + self, + patch_size, + num_pools=[5,5,5], + num_classes=1, + factor=32, + encoder_ckpt = None, + model_ckpt = None, + use_deep=True, + in_planes = 1, + ): + super(EffUNet, self).__init__() + + pyramid={ # efficientnet b4 + 0: ['_bn0'], + 1: ['_blocks', '1', '_bn2'], + 2: ['_blocks', '5', '_bn2'], + 3: ['_blocks', '9', '_bn2'], + 4: ['_blocks', '21', '_bn2'], + 5: ['_blocks', '31', '_bn2'], + } + # pyramid={ # efficientnet b2 + # 0: ['_bn0'], + # 1: ['_blocks', '1', '_bn2'], + # 2: ['_blocks', '4', '_bn2'], + # 3: ['_blocks', '7', '_bn2'], + # 4: ['_blocks', '15', '_bn2'], + # 5: ['_blocks', '22', '_bn2'], + # } + blocks_args, global_params = efficientnet3d( + # width_coefficient=1.1, # efficientnet b2 + # depth_coefficient=1.2, + # dropout_rate=0.3, + width_coefficient=1.4, # efficientnet b4 + depth_coefficient=1.8, + dropout_rate=0.4, + drop_connect_rate=0.2, + image_size=patch_size, + include_top=False + ) + self.encoder = EfficientNet3D( + blocks_args, + global_params, + in_channels=in_planes, + num_pools=num_pools, + ) + + # load encoder if needed + if encoder_ckpt is not None: + print("Load encoder weights from", encoder_ckpt) + if torch.cuda.is_available(): + self.encoder.cuda() + ckpt = torch.load(encoder_ckpt) + # if 'last_layer.weight' in ckpt['model'].keys(): + # del ckpt['model']['last_layer.weight'] + if 'model' in ckpt.keys(): + # 
remove `module.` prefix + state_dict = {k.replace("module.", ""): v for k, v in ckpt['model'].items()} + # remove `0.` prefix induced by the sequential wrapper + state_dict = {k.replace("0.layers", "layers"): v for k, v in state_dict.items()} + print(self.encoder.load_state_dict(state_dict, strict=False)) + elif 'teacher' in ckpt.keys(): + # remove `module.` prefix + state_dict = {k.replace("module.", ""): v for k, v in ckpt['teacher'].items()} + # remove `backbone.` prefix induced by multicrop wrapper + state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} + print(self.encoder.load_state_dict(state_dict, strict=False)) + else: + print("[Warning] the following encoder couldn't be loaded, wrong key:", encoder_ckpt) + + self.pyramid = get_pyramid(self.encoder, pyramid) # only the first five elements of the list are used + + # hook the pyramid + self.down = {} + for i in range(len(self.pyramid)): + self.pyramid[i].register_forward_hook(self.get_activation(i)) + + self.decoder = VGGDecoder( + EncoderBlock, + # SmallEncoderBlock, + num_pools=num_pools, + num_classes=num_classes, + factor_e=[get_outfmaps(l) for l in self.pyramid][::-1], + factor_d=[get_outfmaps(l)*2 for l in self.pyramid][::-1][1:-1]+[factor], + use_deep=use_deep, + ) + + + if model_ckpt is not None: + print("Load model weights from", model_ckpt) + if torch.cuda.is_available(): + self.cuda() + ckpt = torch.load(model_ckpt) + if 'encoder.last_layer.weight' in ckpt['model'].keys(): + del ckpt['model']['encoder.last_layer.weight'] + self.load_state_dict(ckpt['model']) + + def freeze_encoder(self, freeze=True): + """ + freeze or unfreeze encoder model + """ + if freeze: + print("Freezing encoder weights...") + else: + print("Unfreezing encoder weights...") + for l in self.encoder.parameters(): + l.requires_grad = not freeze + + def unfreeze_encoder(self): + self.freeze_encoder(False) + + def get_activation(self, name): + def hook(model, input, output): + self.down[name] = output + # self.down += [output] # CAREFUL MEMORY LEAK! when using list make sure to empty it during the forward pass! + return hook + + def forward(self, x): # x is an image + self.encoder(x) + out = self.decoder(list(self.down.values())) + # del self.down # CAREFUL MEMORY LEAK! when using list make sure to empty it during the forward pass! + # self.down = [] + return out + +#--------------------------------------------------------------------------- diff --git a/src/biom3d/models/unet3d_vgg_deep.py b/src/biom3d/models/unet3d_vgg_deep.py index 2f8ddad..a9f93ed 100644 --- a/src/biom3d/models/unet3d_vgg_deep.py +++ b/src/biom3d/models/unet3d_vgg_deep.py @@ -12,6 +12,52 @@ # 3D UNet with the previous encoder and decoder class UNet(nn.Module): + """ + A 3D UNet architecture utilizing VGG-style encoder and decoder blocks for volumetric (3D) image segmentation. + + The UNet model is a convolutional neural network for fast and precise segmentation of images. + This implementation incorporates VGG blocks for encoding and decoding, allowing for deep feature extraction + and reconstruction, respectively. The model supports dynamic adjustment of pooling layers and class numbers, + along with optional deep decoder usage and weight initialization from pre-trained checkpoints. + + Parameters + ---------- + num_pools : list of int + A list of integers defining the number of pooling layers for each dimension of the input. Default is [5,5,5]. + num_classes : int + The number of classes for segmentation. Default is 1. 
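+        Following the preprocessing convention used in biom3d, the background
+        class is not counted.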
+ factor : int + The scaling factor for the number of channels in VGG blocks. Default is 32. + encoder_ckpt : str, optional + Path to a checkpoint file from which to load encoder weights. + model_ckpt : str, optional + Path to a checkpoint file from which to load the entire model's weights. + use_deep : bool + Flag to indicate whether to use a deep decoder. Default is True. + in_planes : int + The number of input channels. Default is 1. + flip_strides : bool + Flag to flip strides to match encoder and decoder dimensions. Useful for ensuring dimensionality alignment. + + Attributes + ---------- + encoder : VGGEncoder + The encoder part of the UNet, responsible for downscaling and feature extraction. + decoder : VGGDecoder + The decoder part of the UNet, responsible for upscaling and constructing the segmentation map. + + Methods + ------- + freeze_encoder(freeze=True) + Freezes or unfreezes the encoder's weights. + unfreeze_encoder() + Convenience method to unfreeze the encoder's weights. + load(model_ckpt) + Loads the model's weights from a specified checkpoint. + forward(x) + Defines the computation performed at every call. Applies the encoder and decoder on the input. + + """ def __init__( self, num_pools=[5,5,5], @@ -78,7 +124,12 @@ def __init__( def freeze_encoder(self, freeze=True): """ - freeze or unfreeze encoder model + Freezes or unfreezes the encoder's weights based on the input flag. + + Parameters + ---------- + freeze : bool, optional + If True, the encoder's weights are frozen, otherwise they are unfrozen. Default is True. """ if freeze: print("Freezing encoder weights...") @@ -88,11 +139,18 @@ def freeze_encoder(self, freeze=True): l.requires_grad = not freeze def unfreeze_encoder(self): + """ + Unfreezes the encoder's weights. Convenience method calling `freeze_encoder` with `False`. + """ self.freeze_encoder(False) def load(self, model_ckpt): """Load the model from checkpoint. The checkpoint dictionary must have a 'model' key with the saved model for value. + Parameters + ---------- + model_ckpt : str + The path to the checkpoint file containing the model's weights. """ print("Load model weights from", model_ckpt) if torch.cuda.is_available(): @@ -107,6 +165,19 @@ def load(self, model_ckpt): print(self.load_state_dict(ckpt['model'], strict=False)) def forward(self, x): + """ + Defines the forward pass of the UNet model. + + Parameters + ---------- + x : torch.Tensor + The input tensor representing the image to be segmented. + + Returns + ------- + torch.Tensor + The output segmentation map tensor. 
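+            Typically of shape (N, num_classes, D, H, W) for a batch of 3D volumes.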
+ """ # x is an image out = self.encoder(x) out = self.decoder(out) diff --git a/src/biom3d/omero_downloader.py b/src/biom3d/omero_downloader.py index 5162674..4525ab4 100644 --- a/src/biom3d/omero_downloader.py +++ b/src/biom3d/omero_downloader.py @@ -7,7 +7,7 @@ from omero.gateway import BlitzGateway from omero.cli import cli_login, CLI - +from omero.clients import BaseClient from omero.plugins.download import DownloadControl @@ -96,9 +96,14 @@ def download_datasets(conn, datasets, target_dir): continue dc.download_fileset(conn, fileset, dataset_dir) -def download_object(username, password, hostname, obj, target_dir): - conn = BlitzGateway(username=username, passwd=password, host=hostname, port=4064) - conn.connect() +def download_object(username, password, hostname, obj, target_dir, session_id=None): + if session_id is not None: + client = BaseClient(host=hostname, port=4064) + client.joinSession(session_id) + conn = BlitzGateway(client_obj=client) + else : + conn = BlitzGateway(username=username, passwd=password, host=hostname, port=4064) + conn.connect() try: obj_id = int(obj.split(":")[1]) obj_type = obj.split(":")[0] @@ -123,26 +128,71 @@ def download_object(username, password, hostname, obj, target_dir): download_datasets(conn, datasets, target_dir) - conn.close() + #conn.close() return datasets, target_dir +def download_attachment(hostname, username, password, session_id, attachment_id, config=True): + # Connect to the OMERO server using session ID or username/password + if session_id is not None: + client = BaseClient(host=hostname, port=4064) + client.joinSession(session_id) + conn = BlitzGateway(client_obj=client) + else: + conn = BlitzGateway(username=username, passwd=password, host=hostname, port=4064) + conn.connect() + + try: + # Get the FileAnnotation object by ID + annotation = conn.getObject("FileAnnotation", attachment_id) + if not annotation: + print(f"FileAnnotation with ID {attachment_id} not found.") + return + + # Get the linked OriginalFile object + original_file = annotation.getFile() + if original_file is None: + print(f'No OriginalFile linked to annotation ID {attachment_id}') + return + + file_id = original_file.id + file_name = original_file.name + file_size = original_file.size + + print(f"File ID: {file_id}, Name: {file_name}, Size: {file_size}") + + if config : file_path = os.path.join("configs", file_name) + else : file_path = os.path.join("logs", file_name) + + # Download the file data in chunks + print(f"\nDownloading file to {file_path}...") + with open(file_path, 'wb') as f: + for chunk in annotation.getFileInChunks(): + f.write(chunk) + return file_path + + finally: + # Close the connection + print("Downloaded!") def main(argv): parser = argparse.ArgumentParser() parser.add_argument('--obj', help="Download object: 'Project:ID' or 'Dataset:ID'") - parser.add_argument('--target', + parser.add_argument('--target', help="Directory name to download into") - parser.add_argument('--username', + parser.add_argument('--username', default=None, help="User name") - parser.add_argument('--password', + parser.add_argument('--password', default=None, help="Password") - parser.add_argument('--hostname', + parser.add_argument('--hostname',default=None, help="Host name") + parser.add_argument('--session_id',default=None, + help="Session ID") args = parser.parse_args(argv) - download_object(args.username, args.password, args.hostname, args.obj, args.target) + + download_object(args.username, args.password, args.hostname, args.obj, args.target, args.session_id) if 
__name__ == '__main__': main(sys.argv[1:]) \ No newline at end of file diff --git a/src/biom3d/omero_pred.py b/src/biom3d/omero_pred.py index c3c09fa..527fa8f 100644 --- a/src/biom3d/omero_pred.py +++ b/src/biom3d/omero_pred.py @@ -6,6 +6,7 @@ import argparse import os +import shutil from omero.cli import cli_login from omero.gateway import BlitzGateway @@ -16,10 +17,10 @@ pass from biom3d import pred -def run(obj, target, log, dir_out, host=None, user=None, pwd=None, upload_id=None, ext="_predictions"): +def run(obj, target, log, dir_out, attachment=None, host=None, user=None, pwd=None, upload_id=None,ext="_predictions", session_id=None): print("Start dataset/project downloading...") if host is not None: - datasets, dir_in = omero_downloader.download_object(user, pwd, host, obj, target) + datasets, dir_in = omero_downloader.download_object(user, pwd, host, obj, target, session_id) else: with cli_login() as cli: datasets, dir_in = omero_downloader.download_object_cli(cli, obj, target) @@ -43,7 +44,15 @@ def run(obj, target, log, dir_out, host=None, user=None, pwd=None, upload_id=Non dataset_name = os.path.basename(os.path.dirname(dir_in)) dataset_name += ext - omero_uploader.run(user, pwd, host,upload_id,dataset_name,dir_out) + omero_uploader.run(username=user,password= pwd,hostname= host,project=upload_id, attachment=attachment, is_pred=True, dataset_name=dataset_name,path=dir_out ,session_id=session_id) + # Remove all folders (pred, to_pred, attachment File) + + try : + shutil.rmtree(dir_in) + shutil.rmtree(dir_out) + os.remove(attachment+".zip") + except: + pass print("Done prediction!") # print for remote. Format TAG:key:value @@ -83,10 +92,14 @@ def run(obj, target, log, dir_out, host=None, user=None, pwd=None, upload_id=Non help="(optional) Password for Omero server") parser.add_argument('--upload_id', type=int, default=None, help="(optional) Id of Omero Project in which to upload the dataset. Only works with Omero Project Id and folder of images.") + parser.add_argument('--attachment', type=str, default=None, + help="(optional) Attachment file") # parser.add_argument("-e", "--eval_only", default=False, action='store_true', dest='eval_only', # help="Do only the evaluation and skip the prediction (predictions must have been done already.)") parser.add_argument('--ext', type=str, default='_predictions', help='Name of the extension added to the future uploaded Omero dataset.') + parser.add_argument('--session_id', default=None, + help="(optional) Session ID for Omero client") args = parser.parse_args() run( @@ -98,5 +111,7 @@ def run(obj, target, log, dir_out, host=None, user=None, pwd=None, upload_id=Non user=args.username, pwd=args.password, upload_id=args.upload_id, + attachment=args.attachment, ext=args.ext, + session_id=args.session_id ) \ No newline at end of file diff --git a/src/biom3d/omero_preprocess_train.py b/src/biom3d/omero_preprocess_train.py new file mode 100644 index 0000000..6224bb8 --- /dev/null +++ b/src/biom3d/omero_preprocess_train.py @@ -0,0 +1,282 @@ +#--------------------------------------------------------------------------- +# Predictions with Omero +# This script can download data from Omero, compute predictions, +# and upload back into Omero. 
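+# Despite the title above, this module also drives preprocessing and training:
+# the --action flag below selects preprocess | train | preprocess_train | pred.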
+#--------------------------------------------------------------------------- + +import argparse +import os +import shutil +import zipfile +from omero.cli import cli_login +from biom3d import omero_downloader +from biom3d import omero_uploader +from biom3d import omero_pred +from biom3d import preprocess_train +from biom3d import preprocess +from biom3d import train + +def run(obj_raw, obj_mask, num_classes, config_dir, base_config, ct_norm, desc, max_dim, num_epochs, target , action, host=None, user=None, pwd=None, upload_id=None ,dir_out =None, omero_session_id=None): + + if action == "preprocess" or action=="preprocess_train" : + print("Start dataset/project downloading...") + if host is not None and omero_session_id is None: + datasets, dir_in = omero_downloader.download_object(user, pwd, host, obj_raw, target, omero_session_id) + if obj_mask is not None : + datasets_mask, dir_in_mask = omero_downloader.download_object(user, pwd, host, obj_mask, target, omero_session_id) + elif omero_session_id is not None and host is not None: + datasets, dir_in = omero_downloader.download_object(user, pwd, host, obj_raw, target, omero_session_id) + if obj_mask is not None : + datasets_mask, dir_in_mask = omero_downloader.download_object(user, pwd, host, obj_mask, target,omero_session_id) + else: + with cli_login() as cli: + datasets, dir_in = omero_downloader.download_object_cli(cli, obj_raw, target) + if obj_mask is not None : + datasets_mask, dir_in_mask = omero_downloader.download_object_cli(cli, obj_mask, target) + + print("Done downloading dataset!") + + + if 'Dataset' in obj_raw: + dir_in = os.path.join(dir_in, datasets[0].name) + dir_in_mask = os.path.join(dir_in_mask, datasets_mask[0].name) + + print("Start Training with Omero...") + if action == "preprocess_train" : + preprocess_train.preprocess_train( + img_dir=dir_in, + msk_dir=dir_in_mask, + num_classes=num_classes, + config_dir=config_dir, + base_config=base_config, + ct_norm=ct_norm, + desc=desc, + max_dim=max_dim, + num_epochs=num_epochs + ) + + elif action == "preprocess" : + config_path = preprocess.auto_config_preprocess( + img_dir=dir_in, + msk_dir=dir_in_mask, + num_classes=num_classes, + config_dir=config_dir, + base_config=base_config, + ct_norm=ct_norm, + desc=desc, + max_dim=max_dim, + num_epochs=num_epochs + ) + + elif action == "train" : + conf_dir =omero_downloader.download_attachment( + hostname=host, + username=user, + password=pwd, + session_id=omero_session_id, + attachment_id=config_dir, + config=True) + + print("Running training with current configuration file :",conf_dir) + + train.train(config=conf_dir) + try : + shutil.rmtree(conf_dir) + except: + pass + elif action == "pred" : + #Download the model + model =omero_downloader.download_attachment( + hostname=host, + username=user, + password=pwd, + session_id=omero_session_id, + attachment_id=config_dir, + config=False) + # extract the model + log_folder = unzip_file(model, os.path.join("logs")) + + target = "data/to_pred" + if not os.path.isdir(target): + os.makedirs(target, exist_ok=True) + + attachment_file, _ = os.path.splitext(os.path.basename(log_folder)) + + omero_pred.run( + obj=obj_raw, + log=log_folder, + dir_out=os.path.join("data","pred"), + host = host, + session_id=omero_session_id, + attachment=attachment_file, + upload_id=1, + target=target) + + try : + shutil.rmtree(log_folder) + os.remove(model) + except: + pass + # eventually upload the dataset back into Omero [DEPRECATED] + if upload_id is not None and host is not None: + + if action == "train" or 
action == "preprocess_train" : + # For Training + logs_path = "./logs" + if not os.path.exists(logs_path) : + print(f"Directory '{logs_path}' does not exist.") + else: + directories = [d for d in os.listdir(logs_path) if os.path.isdir(os.path.join(logs_path, d))] + if not directories: + print("No directories found in the logs path.") + else: + directories.sort(key=lambda d: os.path.getmtime(os.path.join(logs_path, d)), reverse=True) + last_folder = directories[0] + image_folder = os.path.join(logs_path, last_folder, "image") + plot_learning_curve(os.path.join(logs_path, last_folder)) + omero_uploader.run(username=user, password=pwd, hostname=host, project=upload_id, path = image_folder ,is_pred=False, attachment=last_folder, session_id =omero_session_id) + try : + os.remove(os.path.join(logs_path, last_folder+".zip")) + shutil.rmtree(os.path.join(logs_path, last_folder)) + except: + pass + shutil.rmtree(target) + + print("Done Training!") + # print for remote. Format TAG:key:value + print("REMOTE:dir_out:{}".format(dir_out)) + return dir_out + elif action == "preprocess" : + # For Preprocessing + last_folder = config_path + image_folder = None + print("last folder: ",last_folder) + print("image_folder : ",image_folder) + omero_uploader.run(username=user, password=pwd, hostname=host, project=upload_id, path = image_folder ,is_pred=False, attachment=last_folder, session_id =omero_session_id) + + else: + print("[Error] Type of object unknown {}. It should be 'Dataset' or 'Project'".format(obj_raw)) + + +def load_csv(filename): + from csv import reader + # Open file in read mode + file = open(filename,"r") + # Reading file + lines = reader(file) + + # Converting into a list + data = list(lines) + + return data + +def plot_learning_curve(last_folder): + import matplotlib.pyplot as plt + # CSV file path + print("this is it : ",last_folder) + csv_file = os.path.join(last_folder+"/log/log.csv") + + # PLOT + data = load_csv(csv_file) + # Extract epoch and train_loss, val_loss values + epochs = [int(row[0]) for row in data[1:]] # Skip the header row + train_losses = [float(row[1]) for row in data[1:]] # Skip the header row + val_losses = [float(row[2]) for row in data[1:]] # Skip the header row + + plt.clf() # Clear the current plot + plt.plot(epochs, train_losses ,label='Train loss') + plt.plot(epochs, val_losses , label ='Validation loss') + plt.xlabel('Epoch') + plt.ylabel('Loss') + plt.title('Learning Curves') + plt.grid(True) + plt.legend() + plt.pause(0.1) # Pause for a short duration to allow for updating + # save figure locally + plt.savefig(last_folder+'/image/Learning_curves_plot.png') + +def unzip_file(zip_path, extract_to): + """ + Unzips a zip file to a specified directory and returns the extraction directory including the name of the zip file. 
+ + :param zip_path: Path to the zip file + :param extract_to: Directory to extract the files to + :return: The full path of the directory where the files were extracted + """ + # Get the base name of the zip file without extension + zip_base_name = os.path.splitext(os.path.basename(zip_path))[0] + + # Create the full extraction path + full_extract_path = os.path.join(extract_to, zip_base_name) + + # Ensure the extraction directory exists + if not os.path.exists(full_extract_path): + os.makedirs(full_extract_path) + + with zipfile.ZipFile(zip_path, 'r') as zip_ref: + zip_ref.extractall(full_extract_path) + + print(f"Extracted all files to {full_extract_path}") + return full_extract_path + +if __name__=='__main__': + + # parser + parser = argparse.ArgumentParser(description="Training with Omero.") + parser.add_argument('--raw', type=str, + help="Download Raw Dataset ") + parser.add_argument('--mask', type=str, + help="Download Masks Dataset ") + parser.add_argument('--target', type=str, default="data/to_train/", + help="Directory name to download into") + parser.add_argument('--action', type=str, default="preprocess_train", + help="Action : preprocess | train | preprocess_train ") + parser.add_argument("--num_classes", type=int, default=1, + help="(default=1) Number of classes (types of objects) in the dataset. The background is not included.") + parser.add_argument("--max_dim", type=int, default=128, + help="(default=128) max_dim^3 determines the maximum size of patch for auto-config.") + parser.add_argument("--num_epochs", type=int, default=1000, + help="(default=1000) Number of epochs for the training.") + parser.add_argument("--config_dir", type=str, default='configs/', + help="(default=\'configs/\') Configuration folder to save the auto-configuration.") + parser.add_argument("--base_config", type=str, default=None, + help="(default=None) Optional. Path to an existing configuration file which will be updated with the preprocessed values.") + parser.add_argument("--desc", type=str, default='unet_default', + help="(default=unet_default) Optional. A name used to describe the model.") + parser.add_argument("--ct_norm", default=False, action='store_true', dest='ct_norm', + help="(default=False) Whether to use CT-Scan normalization routine (cf. nnUNet).") + parser.add_argument('--hostname', type=str, default=None, + help="(optional) Host name for Omero server. 
If not mentioned use the CLI.") + parser.add_argument('--username', type=str, default=None, + help="(optional) User name for Omero server") + parser.add_argument('--password', type=str, default=None, + help="(optional) Password for Omero server") + parser.add_argument('--session_id', default=None, + help="(optional) Session ID for Omero client") + args = parser.parse_args() + + raw = "Dataset:"+args.raw + if args.action=="preprocess" or args.action=="preprocess_train": + mask = "Dataset:"+args.mask + else : + mask=None + + run( + obj_raw=raw, + obj_mask=mask, + num_classes=args.num_classes, + config_dir=args.config_dir, + base_config=args.base_config, + ct_norm=args.ct_norm, + desc=args.desc, + max_dim=args.max_dim, + num_epochs=args.num_epochs, + target=args.target, + action=args.action, + host=args.hostname, + user=args.username, + pwd=args.password, + upload_id=args.raw, + omero_session_id=args.session_id + ) + diff --git a/src/biom3d/omero_uploader.py b/src/biom3d/omero_uploader.py index b00e631..3faacd0 100644 --- a/src/biom3d/omero_uploader.py +++ b/src/biom3d/omero_uploader.py @@ -43,7 +43,8 @@ import os import platform import sys - +import zipfile + import omero.clients from omero.model import ChecksumAlgorithmI from omero.model import NamedValue @@ -52,6 +53,7 @@ from omero_version import omero_version from omero.callbacks import CmdCallbackI from omero.gateway import BlitzGateway +from omero.clients import BaseClient from ezomero import post_dataset def get_files_for_fileset(fs_path): @@ -68,7 +70,7 @@ def create_fileset(files): fileset = omero.model.FilesetI() for f in files: entry = omero.model.FilesetEntryI() - entry.setClientPath(rstring(f)) + entry.setClientPath(rstring(os.path.basename(f))) # Set only the filename fileset.addFilesetEntry(entry) # Fill version info @@ -91,7 +93,6 @@ def create_fileset(files): fileset.linkJob(upload) return fileset - def create_settings(): """Create ImportSettings and set some values.""" settings = omero.grid.ImportSettings() @@ -171,32 +172,80 @@ def full_import(client, fs_path, wait=-1): finally: proc.close() -def run(username, password, hostname, project, dataset_name, path, wait=-1): - conn = BlitzGateway(username=username, passwd=password, host=hostname, port=4064) - conn.connect() +def run(username, password, hostname, project, attachment, dataset_name=None, path=None, is_pred=False , wait=-1, session_id=None): + dataset_id = project + if session_id is not None: + client = BaseClient(host=hostname, port=4064) + client.joinSession(session_id) + conn = BlitzGateway(client_obj=client) + else: + conn = BlitzGateway(username=username, passwd=password, host=hostname, port=4064) + conn.connect() - if project and not conn.getObject('Project', project): + if project and is_pred and not conn.getObject('Project', project): print ('Project id not found: %s' % project) sys.exit(1) - # create a new Omero Dataset - dataset = post_dataset(conn,dataset_name,project) - - directory_path =str(path) - filees = get_files_for_fileset(directory_path) - for fs_path in filees: - print ('Importing: %s' % fs_path) - rsp = full_import(conn.c, fs_path, wait) - if rsp: - links = [] - for p in rsp.pixels: - print ('Imported Image ID: %d' % p.image.id.val) - if dataset: - link = omero.model.DatasetImageLinkI() - link.parent = omero.model.DatasetI(dataset, False) - link.child = omero.model.ImageI(p.image.id.val, False) - links.append(link) - conn.getUpdateService().saveArray(links, conn.SERVICE_OPTS) + if project and not is_pred : + # Get the dataset by ID + dataset = 
conn.getObject("Dataset", project) + dataset_name = dataset.getName()+"_trained" + parent_project = dataset.listParents() + project = parent_project[0].getId() + + if path is not None : + # create a new Omero Dataset + dataset = post_dataset(conn,dataset_name, project) + directory_path =str(path) + filees = get_files_for_fileset(directory_path) + for fs_path in filees: + print ('Importing: %s' % fs_path) + rsp = full_import(conn.c, fs_path, wait) + if rsp: + links = [] + for p in rsp.pixels: + print ('Imported Image ID: %d' % p.image.id.val) + if dataset: + link = omero.model.DatasetImageLinkI() + link.parent = omero.model.DatasetI(dataset, False) + link.child = omero.model.ImageI(p.image.id.val, False) + links.append(link) + conn.getUpdateService().saveArray(links, conn.SERVICE_OPTS) + dataset_id = dataset + + + if attachment is not None: + if path is not None: + logs_path = "./logs" + last_folder_path = os.path.join(logs_path, "{}".format(attachment)) + zip_file_path = os.path.join(logs_path, "{}.zip".format(attachment)) + # Create a zip file excluding the "image" folder + with zipfile.ZipFile(zip_file_path, 'w', zipfile.ZIP_DEFLATED) as zipf: + for root, dirs, files in os.walk(last_folder_path): + # Exclude the "image" directory + if 'image' in dirs: + dirs.remove('image') + for file in files: + file_path = os.path.join(root, file) + arcname = os.path.relpath(file_path, start=last_folder_path) + zipf.write(file_path, arcname) + + print(f"Zipped folder (excluding 'image' folder): {zip_file_path}") + + dataset = conn.getObject("Dataset", dataset_id) + # Specify a local file e.g. could be result of some analysis + file_to_upload = zip_file_path if path is not None else attachment # This file should already exist + + # create the original file and file annotation (uploads the file etc.) + + print("\nCreating an OriginalFile and FileAnnotation") + file_ann = conn.createFileAnnfromLocalFile( + file_to_upload, mimetype="text/plain", desc=None) + print("Attaching FileAnnotation to Dataset: ", "File ID:", file_ann.getId(), \ + ",", file_ann.getFile().getName(), "Size:", file_ann.getFile().getSize()) + dataset.linkAnnotation(file_ann) # link it to dataset. + + conn.close() if __name__ == '__main__': @@ -206,7 +255,7 @@ def run(username, password, hostname, project, dataset_name, path, wait=-1): parser.add_argument('--wait', type=int, default=-1, help=( 'Wait for this number of seconds for each import to complete. 
' '0: return immediately, -1: wait indefinitely (default)'))
-    parser.add_argument('--dataset_name',
+    parser.add_argument('--dataset_name',default="Biom3d_pred",
        help='Name of the Omero dataset.')
    parser.add_argument('--path',
        help='Files or directories')
@@ -216,13 +265,22 @@ def run(username, password, hostname, project, dataset_name, path, wait=-1):
        help="Password")
    parser.add_argument('--hostname',
        help="Host name")
+    parser.add_argument('--attachment', default=None,
+        help="Attachment file")
+    parser.add_argument('--is_pred', default=False, action='store_true',
+        help="Whether the upload is a prediction rather than a training run.")
+    parser.add_argument('--session_id', default=None,
+        help="Omero Session id")
    args = parser.parse_args()

-    run(args.username, args.password, args.hostname,
+    run(args.username, args.password, hostname=args.hostname,
        project=args.project,
        dataset_name=args.dataset_name,
        path=args.path,
-        wait=args.wait
+        wait=args.wait,
+        attachment=args.attachment,
+        is_pred=args.is_pred,
+        session_id=args.session_id,
    )
diff --git a/src/biom3d/predictors.py b/src/biom3d/predictors.py
index 44cb714..2c72474 100644
--- a/src/biom3d/predictors.py
+++ b/src/biom3d/predictors.py
@@ -143,6 +143,8 @@ def __init__(
         intensity_moments=[],
     ):
+
+        self.fname = fname
         self.patch_size = np.array(patch_size)
         self.median_spacing = np.array(median_spacing)
diff --git a/src/biom3d/preprocess.py b/src/biom3d/preprocess.py
index eae94f7..1a506c3 100644
--- a/src/biom3d/preprocess.py
+++ b/src/biom3d/preprocess.py
@@ -154,10 +154,15 @@ def sanity_check(msk, num_classes=None):
         assert num_classes >= 2
     if len(msk.shape)==4:
+        if msk.shape[0]==1:
+            return sanity_check(msk[0], num_classes=num_classes)
         # if we have 4 dimensions in the mask, we consider it one-hot encoded
         # and thus we perform a sanity check for each channel
-        for i in range(msk.shape[0]):
-            sanity_check(msk[i], num_classes=2)
+        else:
+            new_msk = []
+            for i in range(msk.shape[0]):
+                new_msk+=[sanity_check(msk[i], num_classes=2)]
+            return np.array(new_msk)
     cls = np.arange(num_classes)
     if np.array_equal(uni,cls):
@@ -205,7 +210,7 @@ def seg_preprocessor(
     intensity_moments=[],
     channel_axis=0,
     num_channels=1,
-    ):
+    is_2d=False):
     """Segmentation pre-processing.
     """
     do_msk = msk is not None
@@ -216,31 +221,41 @@
     if do_msk:
         # sanity check
         msk = sanity_check(msk, num_classes)
-
-    # expand image dim
-    if len(img.shape)==3:
-        # keep the input shape, used for preprocessing before prediction
+    # add dimension if it's a 2d image
+    if len(img.shape) == 2 :
+        img = np.expand_dims(img, axis=(0,1))
+    # Expand image dimensions: we consider Z the smallest dimension and put it in the second position [C,Z,Y,X]
+    if is_2d and len(img.shape)==3:
         original_shape = img.shape
-        img = np.expand_dims(img, 0)
-    elif len(img.shape)==4:
-        # we consider as the channel dimension, the smallest dimension
-        # if it is the last dim, then we move it to the first
-        # the size of other dimensions of the image should be bigger than the channel dim.
-        if np.argmin(img.shape)==channel_axis and img.shape[channel_axis]==num_channels:
-            img = np.swapaxes(img, 0, channel_axis)
-        else:
-            print("[Error] Invalid image shape:", img.shape)
+        img = np.expand_dims(img, 1)
-        # keep the input shape, used for preprocessing before prediction
-        original_shape = img.shape
-    else:
-        raise ValueError("[Error] Invalid image shape for 3D image {}. 
Skipping image...".format(img.shape)) + else : + # expand image dim + if len(img.shape)==3: + # keep the input shape, used for preprocessing before prediction + original_shape = img.shape + img = np.expand_dims(img, 0) + elif len(img.shape)==4: + # we consider as the channel dimension, the smallest dimension + # if it is the last dim, then we move it to the first + # the size of other dimensions of the image should be bigger than the channel dim. + if np.argmin(img.shape)==channel_axis and img.shape[channel_axis]==num_channels: + img = np.swapaxes(img, 0, channel_axis) + else: + print("[Error] Invalid image shape:", img.shape) + + # keep the input shape, used for preprocessing before prediction + original_shape = img.shape + else: + raise ValueError("[Error] Invalid image shape for 3D image {}. Skipping image...".format(img.shape)) assert img.shape[0]==num_channels, "[Error] Invalid image shape {}. Expected to have {} numbers of channel at {} channel axis.".format(img.shape, num_channels, channel_axis) # one hot encoding for the mask if needed if do_msk and len(msk.shape)!=4: - if use_one_hot: + if len(msk.shape) == 2: + msk = np.expand_dims(msk,axis=(0,1)) + elif use_one_hot: msk = one_hot_fast(msk, num_classes) if remove_bg: msk = msk[1:] @@ -368,15 +383,16 @@ def __init__( use_tif=False, # use tif instead of npy split_rate_for_single_img=0.25, num_kfolds=5, + is_2d=False, ): assert img_dir!='', "[Error] img_dir must not be empty." - + # fix bug path/folder/ to path/folder if os.path.basename(img_dir)=='': img_dir = os.path.dirname(img_dir) if msk_dir is not None and os.path.basename(msk_dir)=='': msk_dir = os.path.dirname(msk_dir) - + self.is_2d = is_2d self.img_dir=img_dir self.msk_dir=msk_dir self.img_fnames=sorted(os.listdir(self.img_dir)) @@ -417,7 +433,10 @@ def __init__( self.num_channels = 1 self.channel_axis = 0 - + # Make sure the Channel dim is first ( we assume that the channel dim is the smallest dimension ! ) + if is_2d and len(self.median_size)==3: + self.num_channels = np.min(median_size) + self.channel_axis = np.argmin(self.median_size) # if the 3D image has 4 dimensions then there is a channel dimension. if len(self.median_size)==4: # the channel dimension is consider to be the smallest dimension @@ -568,6 +587,7 @@ def run(self, debug=False): intensity_moments =self.intensity_moments, channel_axis =self.channel_axis, num_channels =self.num_channels, + is_2d=self.is_2d, ) else: img, _ = seg_preprocessor( @@ -578,6 +598,7 @@ def run(self, debug=False): intensity_moments =self.intensity_moments, channel_axis =self.channel_axis, num_channels =self.num_channels, + is_2d=self.is_2d, ) # sanity check to be sure that all images have the save number of channel @@ -649,6 +670,7 @@ def auto_config_preprocess( logs_dir='logs/', print_param=False, debug=False, + is_2d=False, ): """Helper function to do auto-config and preprocessing. 
""" @@ -662,6 +684,7 @@ def auto_config_preprocess( print("Standard deviation of intensities:", std) print("0.5% percentile of intensities:", perc_005) print("99.5% percentile of intensities:", perc_995) + print("Image Type: " + ("2D" if is_2d else "3D")) print("") if ct_norm: @@ -690,6 +713,7 @@ def auto_config_preprocess( median_size=median_size, clipping_bounds=clipping_bounds, intensity_moments=intensity_moments, + is_2d=is_2d, ) if not skip_preprocessing: @@ -705,7 +729,10 @@ def auto_config_preprocess( max_dims=(max_dim, max_dim, max_dim), max_batch = len(os.listdir(img_dir))//20, # we limit batch to avoid overfitting ) - + # make sure the Z dim of the patch is equal to 1 ! + if is_2d : + patch[0] = 1 + aug_patch[0] = 1 # convert path for windows systems before writing them if platform=='win32': if p.img_outdir is not None: p.img_outdir = p.img_outdir.replace('\\','\\\\') @@ -735,6 +762,7 @@ def auto_config_preprocess( DESC=desc, NB_EPOCHS=num_epochs, LOG_DIR=logs_dir, + IS_2D=is_2d ) if not print_param: print("Auto-config done! Configuration saved in: ", config_path) @@ -788,6 +816,8 @@ def auto_config_preprocess( help="(default=False) Whether to print auto-config parameters. Used for remote preprocessing using the GUI.") parser.add_argument("--debug", default=False, action='store_true', dest='debug', help="(default=False) Debug mode. Whether to print all image filenames while preprocessing.") + parser.add_argument("--is_2d", default=False, + help="(default=False) Check whether the image has 2d only.") args = parser.parse_args() auto_config_preprocess( @@ -810,6 +840,7 @@ def auto_config_preprocess( logs_dir=args.logs_dir, print_param=args.remote, debug=args.debug, + is_2d=args.is_2d, ) #--------------------------------------------------------------------------- diff --git a/src/biom3d/register.py b/src/biom3d/register.py index c091263..bfdeea2 100644 --- a/src/biom3d/register.py +++ b/src/biom3d/register.py @@ -27,12 +27,16 @@ #--------------------------------------------------------------------------- # model register -from biom3d.models.unet3d_vgg_deep import UNet from biom3d.models.encoder_vgg import VGGEncoder, EncoderBlock +from biom3d.models.unet3d_vgg_deep import UNet +from biom3d.models.encoder_efficientnet3d import EfficientNet3D +from biom3d.models.unet3d_eff import EffUNet models = Dict( - UNet3DVGGDeep =Dict(fct=UNet, kwargs=Dict()), VGG3D =Dict(fct=VGGEncoder, kwargs=Dict(block=EncoderBlock, use_head=True)), + UNet3DVGGDeep =Dict(fct=UNet, kwargs=Dict()), + Eff3D =Dict(fct=EfficientNet3D.from_name, kwargs=Dict()), + EffUNet =Dict(fct=EffUNet, kwargs=Dict()), ) #--------------------------------------------------------------------------- diff --git a/src/biom3d/utils.py b/src/biom3d/utils.py index 1fb9657..eca3c17 100644 --- a/src/biom3d/utils.py +++ b/src/biom3d/utils.py @@ -272,10 +272,9 @@ def tif_read_imagej(img_path, axes_order='CZYX'): img_meta : dict Image metadata. 
""" - with tiff.TiffFile(img_path) as tif: - assert tif.is_imagej - + assert tif.is_imagej + # store img_meta img_meta = {} @@ -309,7 +308,8 @@ def tif_read_imagej(img_path, axes_order='CZYX'): img = tiff.tifffile.transpose_axes(img, series.axes, axes_order) - img_meta["axes"] = axes_order + #img_meta["axes"] = axes_order + img_meta["axes"] = series.axes return img, img_meta @@ -1286,4 +1286,4 @@ def __str__(self): self.start_time=time() return "[DEBUG] name: {}, count: {}, time: {} seconds".format(self.name, self.count, out) -# ---------------------------------------------------------------------------- +# ---------------------------------------------------------------------------- \ No newline at end of file