From 04edf0ccd43dd304b22abdc72d403a53eee0be36 Mon Sep 17 00:00:00 2001
From: taigw
Date: Fri, 29 Jul 2022 10:37:19 +0800
Subject: [PATCH 01/26] add logging to classification agent

---
 pymic/net_run/agent_abstract.py |  3 ++-
 pymic/net_run/agent_cls.py      | 13 +++++++------
 pymic/net_run/agent_seg.py      |  5 +++--
 3 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/pymic/net_run/agent_abstract.py b/pymic/net_run/agent_abstract.py
index 420e2f9..fce71f1 100644
--- a/pymic/net_run/agent_abstract.py
+++ b/pymic/net_run/agent_abstract.py
@@ -3,6 +3,7 @@
 import os
 import random
+import logging
 import torch
 import numpy as np
 import torch.optim as optim
@@ -42,7 +43,7 @@ def __init__(self, config, stage = 'train'):
         self.random_seed   = config['training'].get('random_seed', 1)
         if(self.deterministic):
             seed_torch(self.random_seed)
-            print("deterministric is true")
+            logging.info("deterministic is true")

     def set_datasets(self, train_set, valid_set, test_set):
         self.train_set = train_set
diff --git a/pymic/net_run/agent_cls.py b/pymic/net_run/agent_cls.py
index 4a80532..71d7c30 100644
--- a/pymic/net_run/agent_cls.py
+++ b/pymic/net_run/agent_cls.py
@@ -3,6 +3,7 @@
 import copy
 import csv
+import logging
 import time
 import torch
 from torchvision import transforms
@@ -71,7 +72,7 @@ def create_network(self):
         else:
             self.net.double()
         param_number = sum(p.numel() for p in self.net.parameters() if p.requires_grad)
-        print('parameter number:', param_number)
+        logging.info('parameter number {0:}'.format(param_number))

     def get_parameters_to_update(self):
         params = self.net.get_parameters_to_update()
@@ -176,10 +177,10 @@ def write_scalars(self, train_scalars, valid_scalars, glob_it):
         self.summ_writer.add_scalars('loss', loss_scalar, glob_it)
         self.summ_writer.add_scalars(metrics, acc_scalar, glob_it)

-        print("{0:} it {1:}".format(str(datetime.now())[:-7], glob_it))
-        print('train loss {0:.4f}, avg {1:} {2:.4f}'.format(
+        logging.info("{0:} it {1:}".format(str(datetime.now())[:-7], glob_it))
+        logging.info('train loss {0:.4f}, avg {1:} {2:.4f}'.format(
             train_scalars['loss'], metrics, train_scalars[metrics]))
-        print('valid loss {0:.4f}, avg {1:} {2:.4f}'.format(
+        logging.info('valid loss {0:.4f}, avg {1:} {2:.4f}'.format(
             valid_scalars['loss'], metrics, valid_scalars[metrics]))

     def train_valid(self):
@@ -218,7 +219,7 @@ def train_valid(self):

         self.trainIter = iter(self.train_loader)

-        print("{0:} training start".format(str(datetime.now())[:-7]))
+        logging.info("{0:} training start".format(str(datetime.now())[:-7]))
         self.summ_writer = SummaryWriter(self.config['training']['ckpt_save_dir'])
         for it in range(iter_start, iter_max, iter_valid):
             train_scalars = self.training()
@@ -252,7 +253,7 @@ def train_valid(self):
         txt_file = open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefx), 'wt')
         txt_file.write(str(self.max_val_it))
         txt_file.close()
-        print('The best perfroming iter is {0:}, valid {1:} {2:}'.format(\
+        logging.info('The best performing iter is {0:}, valid {1:} {2:}'.format(\
             self.max_val_it, metrics, self.max_val_score))
         self.summ_writer.close()

diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py
index e4bb97c..e1d4cc4 100644
--- a/pymic/net_run/agent_seg.py
+++ b/pymic/net_run/agent_seg.py
@@ -323,6 +323,7 @@ def train_valid(self):
             t0 = time.time()
             train_scalars = self.training()
             t1 = time.time()
+
             valid_scalars = self.validation()
             t2 = time.time()
             self.glob_it = it + iter_valid
@@ -428,7 +429,7 @@ def test_time_dropout(m):
                 self.save_ouputs(data)
         infer_time_list =
np.asarray(infer_time_list)
         time_avg, time_std = infer_time_list.mean(), infer_time_list.std()
-        print("testing time {0:} +/- {1:}".format(time_avg, time_std))
+        logging.info("testing time {0:} +/- {1:}".format(time_avg, time_std))

     def infer_with_multiple_checkpoints(self):
         """
@@ -482,7 +483,7 @@ def infer_with_multiple_checkpoints(self):
                 self.save_ouputs(data)
         infer_time_list = np.asarray(infer_time_list)
         time_avg, time_std = infer_time_list.mean(), infer_time_list.std()
-        print("testing time {0:} +/- {1:}".format(time_avg, time_std))
+        logging.info("testing time {0:} +/- {1:}".format(time_avg, time_std))

     def save_ouputs(self, data):
         output_dir = self.config['testing']['output_dir']

From c2bbec5522d4c1a7f1ac0fa29d1d2d7e764946d6 Mon Sep 17 00:00:00 2001
From: taigw
Date: Sat, 30 Jul 2022 10:25:29 +0800
Subject: [PATCH 02/26] add lr scheduler

---
 pymic/net_run/agent_abstract.py | 18 ++++++-------
 pymic/net_run/agent_seg.py      | 33 ++++++++++++++++--------
 pymic/net_run/get_optimizer.py  | 45 ++++++++++++++++++++++++---------
 pymic/util/general.py           | 27 ++++++++++++++++++++
 4 files changed, 92 insertions(+), 31 deletions(-)
 create mode 100644 pymic/util/general.py

diff --git a/pymic/net_run/agent_abstract.py b/pymic/net_run/agent_abstract.py
index fce71f1..d701534 100644
--- a/pymic/net_run/agent_abstract.py
+++ b/pymic/net_run/agent_abstract.py
@@ -8,7 +8,7 @@
 import numpy as np
 import torch.optim as optim
 from abc import ABCMeta, abstractmethod
-from pymic.net_run.get_optimizer import get_optimiser
+from pymic.net_run.get_optimizer import get_lr_scheduler, get_optimizer

 def seed_torch(seed=1):
     random.seed(seed)
@@ -72,7 +72,9 @@ def get_checkpoint_name(self):
         ckpt_mode = self.config['testing']['ckpt_mode']
         if(ckpt_mode == 0 or ckpt_mode == 1):
             ckpt_dir = self.config['training']['ckpt_save_dir']
-            ckpt_prefix = ckpt_dir.split('/')[-1]
+            ckpt_prefix = self.config['training'].get('ckpt_prefix', None)
+            if(ckpt_prefix is None):
+                ckpt_prefix = ckpt_dir.split('/')[-1]
             txt_name = ckpt_dir + '/' + ckpt_prefix
             txt_name += "_latest.txt" if ckpt_mode == 0 else "_best.txt"
             with open(txt_name, 'r') as txt_file:
@@ -146,19 +148,17 @@ def worker_init_fn(worker_id):
                 batch_size = bn_test, shuffle=False, num_workers= bn_test)

     def create_optimizer(self, params):
+        opt_params = self.config['training']
         if(self.optimizer is None):
-            self.optimizer = get_optimiser(self.config['training']['optimizer'],
-                    params,
-                    self.config['training'])
+            self.optimizer = get_optimizer(opt_params['optimizer'],
+                    params, opt_params)
         last_iter = -1
         if(self.checkpoint is not None):
             self.optimizer.load_state_dict(self.checkpoint['optimizer_state_dict'])
             last_iter = self.checkpoint['iteration'] - 1
         if(self.scheduler is None):
-            self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer,
-                    self.config['training']['lr_milestones'],
-                    self.config['training']['lr_gamma'],
-                    last_epoch = last_iter)
+            opt_params["last_iter"] = last_iter
+            self.scheduler = get_lr_scheduler(self.optimizer, opt_params)

     def convert_tensor_type(self, input_tensor):
         if(self.tensor_type == 'float'):
diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py
index e1d4cc4..6da26cc 100644
--- a/pymic/net_run/agent_seg.py
+++ b/pymic/net_run/agent_seg.py
@@ -26,6 +26,7 @@
 from pymic.loss.seg.util import get_classwise_dice
 from pymic.transform.trans_dict import TransformDict
 from pymic.util.image_process import convert_label
+from pymic.util.general import keyword_match

 class SegmentationAgent(NetRunAgent):
     def __init__(self, config, stage = 'train'):
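Taken together, the training options that this patch reads from config['training'] can be sketched as below; the key names come from create_optimizer() above and from get_lr_scheduler() later in this patch, while the values are purely illustrative:

```python
# hypothetical config['training'] entries for the new scheduler support
config['training'] = {
    'optimizer'      : 'Adam',
    'learning_rate'  : 1e-3,
    'momentum'       : 0.9,   # read by get_optimizer() before branching on the name
    'weight_decay'   : 1e-5,
    'lr_scheduler'   : 'ReduceLROnPlateau',  # or 'MultiStepLR'
    'lr_gamma'       : 0.5,
    'reducelronplateau_patience': 10000,     # in iterations; divided by iter_valid
    'iter_valid'     : 500,
    # 'lr_milestones': [20000, 40000],       # used when lr_scheduler is MultiStepLR
}
```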
@@ -192,10 +193,10 @@ def training(self): # forward + backward + optimize outputs = self.net(inputs) loss = self.get_loss_value(data, outputs, labels_prob) - # if (self.config['training']['use']) loss.backward() self.optimizer.step() - self.scheduler.step() + if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")): + self.scheduler.step() train_loss = train_loss + loss.item() # get dice evaluation for each class @@ -251,15 +252,19 @@ def validation(self): valid_cls_dice = np.asarray(valid_dice_list).mean(axis = 0) valid_avg_dice = valid_cls_dice.mean() + if(keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")): + self.scheduler.step(valid_avg_dice) + valid_scalers = {'loss': valid_avg_loss, 'avg_dice': valid_avg_dice,\ 'class_dice': valid_cls_dice} return valid_scalers - def write_scalars(self, train_scalars, valid_scalars, glob_it): + def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): loss_scalar ={'train':train_scalars['loss'], 'valid':valid_scalars['loss']} dice_scalar ={'train':train_scalars['avg_dice'], 'valid':valid_scalars['avg_dice']} self.summ_writer.add_scalars('loss', loss_scalar, glob_it) self.summ_writer.add_scalars('dice', dice_scalar, glob_it) + self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it) class_num = self.config['network']['class_num'] for c in range(class_num): cls_dice_scalar = {'train':train_scalars['class_dice'][c], \ @@ -282,13 +287,14 @@ def train_valid(self): self.device = torch.device("cuda:{0:}".format(device_ids[0])) self.net.to(self.device) ckpt_dir = self.config['training']['ckpt_save_dir'] - if(ckpt_dir[-1] == "/"): - ckpt_dir = ckpt_dir[:-1] - ckpt_prefx = ckpt_dir.split('/')[-1] + ckpt_prefix = self.config['training'].get('ckpt_prefix', None) + if(ckpt_prefix is None): + ckpt_prefix = ckpt_dir.split('/')[-1] iter_start = self.config['training']['iter_start'] iter_max = self.config['training']['iter_max'] iter_valid = self.config['training']['iter_valid'] iter_save = self.config['training']['iter_save'] + early_stop_it = self.config['training'].get('early_stop_patience', None) if(isinstance(iter_save, (tuple, list))): iter_save_list = iter_save else: @@ -299,7 +305,7 @@ def train_valid(self): self.best_model_wts = None self.checkpoint = None if(iter_start > 0): - checkpoint_file = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefx, iter_start) + checkpoint_file = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, iter_start) self.checkpoint = torch.load(checkpoint_file, map_location = self.device) # assert(self.checkpoint['iteration'] == iter_start) if(len(device_ids) > 1): @@ -320,6 +326,7 @@ def train_valid(self): self.summ_writer = SummaryWriter(self.config['training']['ckpt_save_dir']) self.glob_it = iter_start for it in range(iter_start, iter_max, iter_valid): + lr_value = self.optimizer.param_groups[0]['lr'] t0 = time.time() train_scalars = self.training() t1 = time.time() @@ -327,9 +334,10 @@ def train_valid(self): valid_scalars = self.validation() t2 = time.time() self.glob_it = it + iter_valid - logging.info("{0:} it {1:}".format(str(datetime.now())[:-7], self.glob_it)) + logging.info("\n{0:} it {1:}".format(str(datetime.now())[:-7], self.glob_it)) + logging.info('learning rate {0:}'.format(lr_value)) logging.info("training/validation time: {0:.2f}s/{1:.2f}s".format(t1-t0, t2-t1)) - self.write_scalars(train_scalars, valid_scalars, self.glob_it) + self.write_scalars(train_scalars, valid_scalars, lr_value, self.glob_it) if(valid_scalars['avg_dice'] > 
self.max_val_dice): self.max_val_dice = valid_scalars['avg_dice'] self.max_val_it = self.glob_it @@ -338,7 +346,9 @@ def train_valid(self): else: self.best_model_wts = copy.deepcopy(self.net.state_dict()) - if (self.glob_it in iter_save_list): + stop_now = True if(early_stop_it is not None and \ + self.glob_it - self.max_val_it > early_stop_it) else False + if ((self.glob_it in iter_save_list) or stop_now): save_dict = {'iteration': self.glob_it, 'valid_pred': valid_scalars['avg_dice'], 'model_state_dict': self.net.module.state_dict() \ @@ -349,6 +359,9 @@ def train_valid(self): txt_file = open("{0:}/{1:}_latest.txt".format(ckpt_dir, ckpt_prefx), 'wt') txt_file.write(str(self.glob_it)) txt_file.close() + if(stop_now): + logging.info("The training is early stopped") + break # save the best performing checkpoint save_dict = {'iteration': self.max_val_it, 'valid_pred': self.max_val_dice, diff --git a/pymic/net_run/get_optimizer.py b/pymic/net_run/get_optimizer.py index 7170b6e..e475286 100644 --- a/pymic/net_run/get_optimizer.py +++ b/pymic/net_run/get_optimizer.py @@ -2,33 +2,54 @@ from __future__ import print_function, division import torch -import torch.optim as optim +from torch import optim +from torch.optim import lr_scheduler +from pymic.util.general import keyword_match -def get_optimiser(name, net_params, optim_params): +def get_optimizer(name, net_params, optim_params): lr = optim_params['learning_rate'] momentum = optim_params['momentum'] weight_decay = optim_params['weight_decay'] - if(name == "SGD"): + if(keyword_match(name, "SGD")): return optim.SGD(net_params, lr, momentum = momentum, weight_decay = weight_decay) - elif(name == "Adam"): + elif(keyword_match(name, "Adam")): return optim.Adam(net_params, lr, weight_decay = weight_decay) - elif(name == "SparseAdam"): + elif(keyword_match(name, "SparseAdam")): return optim.SparseAdam(net_params, lr) - elif(name == "Adadelta"): + elif(keyword_match(name, "Adadelta")): return optim.Adadelta(net_params, lr, weight_decay = weight_decay) - elif(name == "Adagrad"): + elif(keyword_match(name, "Adagrad")): return optim.Adagrad(net_params, lr, weight_decay = weight_decay) - elif(name == "Adamax"): + elif(keyword_match(name, "Adamax")): return optim.Adamax(net_params, lr, weight_decay = weight_decay) - elif(name == "ASGD"): + elif(keyword_match(name, "ASGD")): return optim.ASGD(net_params, lr, weight_decay = weight_decay) - elif(name == "LBFGS"): + elif(keyword_match(name, "LBFGS")): return optim.LBFGS(net_params, lr) - elif(name == "RMSprop"): + elif(keyword_match(name, "RMSprop")): return optim.RMSprop(net_params, lr, momentum = momentum, weight_decay = weight_decay) - elif(name == "Rprop"): + elif(keyword_match(name, "Rprop")): return optim.Rprop(net_params, lr) else: raise ValueError("unsupported optimizer {0:}".format(name)) + + +def get_lr_scheduler(optimizer, sched_params): + name = sched_params["lr_scheduler"] + lr_gamma = sched_params["lr_gamma"] + if(keyword_match(name, "ReduceLROnPlateau")): + patience_it = sched_params["ReduceLROnPlateau_patience".lower()] + val_it = sched_params["iter_valid"] + patience = patience_it / val_it + scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, + mode = "max", factor=lr_gamma, patience = patience) + elif(keyword_match(name, "MultiStepLR")): + lr_milestones = sched_params["lr_milestones"] + last_iter = sched_params["last_iter"] + scheduler = lr_scheduler.MultiStepLR(optimizer, + lr_milestones, lr_gamma, last_iter) + else: + raise ValueError("unsupported lr scheduler {0:}".format(name)) + 
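    # note: MultiStepLR above is stepped once per training iteration, so
    # lr_milestones and last_iter are counted in iterations, while
    # ReduceLROnPlateau is stepped once per validation round with the average
    # Dice as the monitored metric; its patience is therefore converted from
    # iterations to validation steps above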
    return scheduler
\ No newline at end of file
diff --git a/pymic/util/general.py b/pymic/util/general.py
new file mode 100644
index 0000000..063d654
--- /dev/null
+++ b/pymic/util/general.py
@@ -0,0 +1,27 @@
+# -*- coding: utf-8 -*-
+from __future__ import print_function, division
+import torch
+import numpy as np
+
+def keyword_match(a,b):
+    return a.lower() == b.lower()
+
+def get_one_hot_seg(label, class_num):
+    """
+    convert a segmentation label to one-hot
+    label: a tensor with a shape of [N, 1, D, H, W] or [N, 1, H, W]
+    class_num: class number.
+    output: an one-hot tensor with a shape of [N, C, D, H, W] or [N, C, H, W]
+    """
+    size = list(label.size())
+    if(size[1] != 1):
+        raise ValueError("The channel should be 1, \
+            rather than {0:} before one-hot encoding".format(size[1]))
+    label = label.view(-1)
+    ones  = torch.sparse.torch.eye(class_num).to(label.device)
+    one_hot = ones.index_select(0, label)
+    size.append(class_num)
+    one_hot = one_hot.view(*size)
+    one_hot = torch.transpose(one_hot, 1, -1)
+    one_hot = torch.squeeze(one_hot, -1)
+    return one_hot
\ No newline at end of file

From bde766242c45b938b73d42bd1f8858df47564581 Mon Sep 17 00:00:00 2001
From: taigw
Date: Sat, 30 Jul 2022 14:01:28 +0800
Subject: [PATCH 03/26] add early stop and update readme

set early stop in agent_seg
update readme for annotation-efficient learning
---
 README.md                  | 18 ++++++++++--------
 pymic/net_run/agent_seg.py | 16 +++++++++------
 2 files changed, 19 insertions(+), 15 deletions(-)

diff --git a/README.md b/README.md
index d0aa89e..d6007e8 100644
--- a/README.md
+++ b/README.md
@@ -1,29 +1,31 @@
 # PyMIC: A Pytorch-Based Toolkit for Medical Image Computing

-PyMIC is a pytorch-based toolkit for medical image computing with deep learning. Despite that pytorch is a fantastic platform for deep learning, using it for medical image computing is not straightforward as medical images are often with higher dimension, multiple modalities and low contrast. The toolkit is developed to facilitate medical image computing researchers so that training and testing deep learning models become easier. It is very friendly to researchers who are new to this area. Even without writing any code, you can use PyMIC commands to train and test a model by simply editing configure files.
+PyMIC is a pytorch-based toolkit for medical image computing with annotation-efficient deep learning. Although pytorch is a fantastic platform for deep learning, using it for medical image computing is not straightforward, as medical images often have high dimension and large volume, multiple modalities, and are difficult to annotate. This toolkit is developed to facilitate medical image computing research so that training and testing deep learning models become easier. It is very friendly to researchers who are new to this area. Even without writing any code, you can use PyMIC commands to train and test a model by simply editing configuration files.
+
+PyMIC is developed to support learning with imperfect labels, including semi-supervised and weakly supervised learning, and learning with noisy annotations.

 Currently PyMIC supports 2D/3D medical image classification and segmentation, and it is still under development. It was originally developed for COVID-19 pneumonia lesion segmentation from CT images. If you use this toolkit, please cite the following paper:

-
 * G. Wang, X. Liu, C. Li, Z. Xu, J. Ruan, H. Zhu, T. Meng, K. Li, N. Huang, S. Zhang.
[A Noise-robust Framework for Automatic Segmentation of COVID-19 Pneumonia Lesions from CT Images.][tmi2020]
 IEEE Transactions on Medical Imaging. 39(8):2653-2663, 2020. DOI: [10.1109/TMI.2020.3000314][tmi2020]

 [tmi2020]:https://ieeexplore.ieee.org/document/9109297

-# Advantages
-PyMIC provides some basic modules for medical image computing that can be share by different applications. We currently provide the following functions:
+# Features
+PyMIC provides flexible modules for medical image computing tasks including classification and segmentation. It currently provides the following functions:
+* Support for annotation-efficient image segmentation, especially for semi-supervised, weakly-supervised and noisy-label learning.
+* User friendly: For beginners, you only need to edit the configuration files for model training and inference, without writing code. For advanced users, you can customize different modules (networks, loss functions, training pipeline, etc.) and easily integrate them into PyMIC.
 * Easy-to-use I/O interface to read and write different 2D and 3D images.
+* Various data pre-processing/transformation methods before sending a tensor into a network.
+* Implementation of typical neural networks for medical image segmentation.
 * Re-useable training and testing pipeline that can be transferred to different tasks.
-* Various data pre-processing methods before sending a tensor into a network.
-* Implementation of loss functions, especially for image segmentation.
-* Implementation of evaluation metrics to get quantitative evaluation of your methods (for segmentation).
+* Evaluation metrics for quantitative evaluation of your methods.

 # Usage
 ## Requirement
 * [Pytorch][torch_link] version >=1.0.1
 * [TensorboardX][tbx_link] to visualize training performance
 * Some common python packages such as Numpy, Pandas, SimpleITK
+* See `requirements.txt` for details.

 [torch_link]:https://pytorch.org/
 [tbx_link]:https://github.com/lanpa/tensorboardX

@@ -42,7 +44,7 @@ python setup.py install
 ```

 ## Examples
-[PyMIC_examples][examples] provides some examples of starting to use PyMIC. For beginners, you only need to simply change the configuration files to select different datasets, networks and training methods for running the code. For advanced users, you can develop your own modules based on this package. You can find both types of examples
+[PyMIC_examples][examples] provides some examples of starting to use PyMIC. At the beginning, you only need to edit the configuration files to select different datasets, networks and training methods for running the code. When you are more familiar with PyMIC, you can customize different modules in the PyMIC package.
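Besides editing configuration files, a run can also be driven directly from Python. The sketch below follows the agent pattern that appears in the main() functions later in this series; the configuration file name and the parse_config call are assumptions:

```python
from pymic.util.parse_config import *
from pymic.net_run.agent_seg import SegmentationAgent

config = parse_config("my_config.cfg")       # hypothetical configuration file
agent  = SegmentationAgent(config, 'train')  # use 'test' for inference
agent.run()
```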
You can find both types of examples: [examples]: https://github.com/HiLab-git/PyMIC_examples diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index 6da26cc..0cc59b3 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -293,12 +293,14 @@ def train_valid(self): iter_start = self.config['training']['iter_start'] iter_max = self.config['training']['iter_max'] iter_valid = self.config['training']['iter_valid'] - iter_save = self.config['training']['iter_save'] + iter_save = self.config['training'].get('iter_save', None) early_stop_it = self.config['training'].get('early_stop_patience', None) - if(isinstance(iter_save, (tuple, list))): + if(iter_save is None): + iter_save_list = [iter_max] + elif(isinstance(iter_save, (tuple, list))): iter_save_list = iter_save else: - iter_save_list = range(iter_start, iter_max +1, iter_save) + iter_save_list = range(iter_start, iter_max + 1, iter_save) self.max_val_dice = 0.0 self.max_val_it = 0 @@ -354,9 +356,9 @@ def train_valid(self): 'model_state_dict': self.net.module.state_dict() \ if len(device_ids) > 1 else self.net.state_dict(), 'optimizer_state_dict': self.optimizer.state_dict()} - save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefx, self.glob_it) + save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.glob_it) torch.save(save_dict, save_name) - txt_file = open("{0:}/{1:}_latest.txt".format(ckpt_dir, ckpt_prefx), 'wt') + txt_file = open("{0:}/{1:}_latest.txt".format(ckpt_dir, ckpt_prefix), 'wt') txt_file.write(str(self.glob_it)) txt_file.close() if(stop_now): @@ -367,9 +369,9 @@ def train_valid(self): 'valid_pred': self.max_val_dice, 'model_state_dict': self.best_model_wts, 'optimizer_state_dict': self.optimizer.state_dict()} - save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefx, self.max_val_it) + save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.max_val_it) torch.save(save_dict, save_name) - txt_file = open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefx), 'wt') + txt_file = open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefix), 'wt') txt_file.write(str(self.max_val_it)) txt_file.close() logging.info('The best performing iter is {0:}, valid dice {1:}'.format(\ From 15b227a6724afefc7dae623c6ab545a674333a86 Mon Sep 17 00:00:00 2001 From: taigw Date: Sun, 31 Jul 2022 16:13:43 +0800 Subject: [PATCH 04/26] update intensity transform add gaussian noise --- pymic/transform/gamma_correction.py | 40 ----------------- pymic/transform/intensity.py | 70 +++++++++++++++++++++++++++++ pymic/transform/trans_dict.py | 6 ++- 3 files changed, 74 insertions(+), 42 deletions(-) delete mode 100644 pymic/transform/gamma_correction.py create mode 100644 pymic/transform/intensity.py diff --git a/pymic/transform/gamma_correction.py b/pymic/transform/gamma_correction.py deleted file mode 100644 index 4a88f1c..0000000 --- a/pymic/transform/gamma_correction.py +++ /dev/null @@ -1,40 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division - -import torch -import json -import math -import random -import numpy as np -from scipy import ndimage -from pymic.transform.abstract_transform import AbstractTransform -from pymic.util.image_process import * - - -class ChannelWiseGammaCorrection(AbstractTransform): - """ - apply random gamma correction to each channel - """ - def __init__(self, params): - """ - (gamma_min, gamma_max) specify the range of gamma - """ - super(ChannelWiseGammaCorrection, self).__init__(params) - self.gamma_min = 
params['ChannelWiseGammaCorrection_gamma_min'.lower()]
-        self.gamma_max = params['ChannelWiseGammaCorrection_gamma_max'.lower()]
-        self.inverse = params.get('ChannelWiseGammaCorrection_inverse'.lower(), False)
-
-    def __call__(self, sample):
-        image= sample['image']
-        for chn in range(image.shape[0]):
-            gamma_c = random.random() * (self.gamma_max - self.gamma_min) + self.gamma_min
-            img_c = image[chn]
-            v_min = img_c.min()
-            v_max = img_c.max()
-            img_c = (img_c - v_min)/(v_max - v_min)
-            img_c = np.power(img_c, gamma_c)*(v_max - v_min) + v_min
-            image[chn] = img_c
-
-        sample['image'] = image
-        return sample
-
diff --git a/pymic/transform/intensity.py b/pymic/transform/intensity.py
new file mode 100644
index 0000000..b9e6070
--- /dev/null
+++ b/pymic/transform/intensity.py
@@ -0,0 +1,70 @@
+# -*- coding: utf-8 -*-
+from __future__ import print_function, division
+
+import torch
+import json
+import math
+import random
+import numpy as np
+from scipy import ndimage
+from pymic.transform.abstract_transform import AbstractTransform
+from pymic.util.image_process import *
+
+
+class GammaCorrection(AbstractTransform):
+    """
+    apply random gamma correction to each channel
+    """
+    def __init__(self, params):
+        """
+        (gamma_min, gamma_max) specify the range of gamma
+        """
+        super(GammaCorrection, self).__init__(params)
+        self.channels  = params['GammaCorrection_channels'.lower()]
+        self.gamma_min = params['GammaCorrection_gamma_min'.lower()]
+        self.gamma_max = params['GammaCorrection_gamma_max'.lower()]
+        self.prob      = params.get('GammaCorrection_probability'.lower(), 0.5)
+        self.inverse   = params.get('GammaCorrection_inverse'.lower(), False)
+
+    def __call__(self, sample):
+        if(np.random.uniform() > self.prob):
+            return sample
+        image= sample['image']
+        for chn in self.channels:
+            gamma_c = random.random() * (self.gamma_max - self.gamma_min) + self.gamma_min
+            img_c = image[chn]
+            v_min = img_c.min()
+            v_max = img_c.max()
+            img_c = (img_c - v_min)/(v_max - v_min)
+            img_c = np.power(img_c, gamma_c)*(v_max - v_min) + v_min
+            image[chn] = img_c
+
+        sample['image'] = image
+        return sample
+
+class GaussianNoise(AbstractTransform):
+    """
+    add random Gaussian noise to each channel
+    """
+    def __init__(self, params):
+        """
+        (mean, std) specify the Gaussian distribution of the noise
+        """
+        super(GaussianNoise, self).__init__(params)
+        self.channels = params['GaussianNoise_channels'.lower()]
+        self.mean     = params['GaussianNoise_mean'.lower()]
+        self.std      = params['GaussianNoise_std'.lower()]
+        self.prob     = params.get('GaussianNoise_probability'.lower(), 0.5)
+        self.inverse  = params.get('GaussianNoise_inverse'.lower(), False)
+
+    def __call__(self, sample):
+        if(np.random.uniform() > self.prob):
+            return sample
+        image= sample['image']
+        for chn in self.channels:
+            img_c = image[chn]
+            noise = np.random.normal(self.mean, self.std, img_c.shape)
+            image[chn] = img_c + noise
+
+        sample['image'] = image
+        return sample
diff --git a/pymic/transform/trans_dict.py b/pymic/transform/trans_dict.py
index ae9ce9c..d90e431 100644
--- a/pymic/transform/trans_dict.py
+++ b/pymic/transform/trans_dict.py
@@ -1,8 +1,9 @@
 # -*- coding: utf-8 -*-
 from __future__ import print_function, division
-from pymic.transform.gamma_correction import ChannelWiseGammaCorrection
+from pymic.transform.intensity import *
 from pymic.transform.gray2rgb import GrayscaleToRGB
 from pymic.transform.flip import RandomFlip
+from pymic.transform.intensity import GaussianNoise
 from pymic.transform.pad import Pad
 from pymic.transform.rotate import RandomRotate
from pymic.transform.rescale import Rescale, RandomRescale @@ -12,12 +13,13 @@ from pymic.transform.label_convert import * TransformDict = { - 'ChannelWiseGammaCorrection': ChannelWiseGammaCorrection, 'ChannelWiseThreshold': ChannelWiseThreshold, 'ChannelWiseThresholdWithNormalize': ChannelWiseThresholdWithNormalize, 'CropWithBoundingBox': CropWithBoundingBox, 'CenterCrop': CenterCrop, 'GrayscaleToRGB': GrayscaleToRGB, + 'GammaCorrection': GammaCorrection, + 'GaussianNoise': GaussianNoise, 'LabelConvert': LabelConvert, 'LabelConvertNonzero': LabelConvertNonzero, 'LabelToProbability': LabelToProbability, From 376f2108bc2c8293f467697b4946aea6346e9810 Mon Sep 17 00:00:00 2001 From: taigw Date: Mon, 1 Aug 2022 10:00:52 +0800 Subject: [PATCH 05/26] Update infer_func.py set pre-defined tta for 2d images --- pymic/net_run/infer_func.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/pymic/net_run/infer_func.py b/pymic/net_run/infer_func.py index e603725..35bfb4c 100644 --- a/pymic/net_run/infer_func.py +++ b/pymic/net_run/infer_func.py @@ -131,24 +131,26 @@ def run(self, model, image): tta_mode = self.config.get('tta_mode', 0) if(tta_mode == 0): outputs = self.__infer(image) - elif(tta_mode == 1): # test time augmentation with flip in 2D + elif(tta_mode == 1): + # test time augmentation with flip in 2D + # you may define your own method for test time augmentation outputs1 = self.__infer(image) outputs2 = self.__infer(torch.flip(image, [-2])) - outputs3 = self.__infer(torch.flip(image, [-3])) - outputs4 = self.__infer(torch.flip(image, [-2, -3])) + outputs3 = self.__infer(torch.flip(image, [-1])) + outputs4 = self.__infer(torch.flip(image, [-2, -1])) if(isinstance(outputs1, (tuple, list))): outputs = [] for i in range(len(outputs)): temp_out1 = outputs1[i] temp_out2 = torch.flip(outputs2[i], [-2]) - temp_out3 = torch.flip(outputs3[i], [-3]) - temp_out4 = torch.flip(outputs4[i], [-2, -3]) + temp_out3 = torch.flip(outputs3[i], [-1]) + temp_out4 = torch.flip(outputs4[i], [-2, -1]) temp_mean = (temp_out1 + temp_out2 + temp_out3 + temp_out4) / 4 outputs.append(temp_mean) else: outputs2 = torch.flip(outputs2, [-2]) - outputs3 = torch.flip(outputs3, [-3]) - outputs4 = torch.flip(outputs4, [-2, -3]) + outputs3 = torch.flip(outputs3, [-1]) + outputs4 = torch.flip(outputs4, [-2, -1]) outputs = (outputs1 + outputs2 + outputs3 + outputs4) / 4 else: raise ValueError("Undefined tta_mode {0:}".format(tta_mode)) From 31113ed7cd14d84ea45f4af4ac255e6a66a8bf2e Mon Sep 17 00:00:00 2001 From: taigw Date: Mon, 1 Aug 2022 16:41:00 +0800 Subject: [PATCH 06/26] update ssl and nll rename net_run_noise as net_run_nll update ssl_abstract --- pymic/{net_run_noise => net_run_nll}/cl.py | 0 .../co_teaching.py | 0 pymic/net_run_ssl/ssl_abstract.py | 101 ++++++++++++++++++ pymic/net_run_ssl/ssl_cps.py | 4 +- pymic/net_run_ssl/ssl_em.py | 80 +------------- pymic/net_run_ssl/ssl_mt.py | 4 +- pymic/net_run_ssl/ssl_urpc.py | 4 +- 7 files changed, 110 insertions(+), 83 deletions(-) rename pymic/{net_run_noise => net_run_nll}/cl.py (100%) rename pymic/{net_run_noise => net_run_nll}/co_teaching.py (100%) create mode 100644 pymic/net_run_ssl/ssl_abstract.py diff --git a/pymic/net_run_noise/cl.py b/pymic/net_run_nll/cl.py similarity index 100% rename from pymic/net_run_noise/cl.py rename to pymic/net_run_nll/cl.py diff --git a/pymic/net_run_noise/co_teaching.py b/pymic/net_run_nll/co_teaching.py similarity index 100% rename from pymic/net_run_noise/co_teaching.py rename to 
pymic/net_run_nll/co_teaching.py
diff --git a/pymic/net_run_ssl/ssl_abstract.py b/pymic/net_run_ssl/ssl_abstract.py
new file mode 100644
index 0000000..acb3b5a
--- /dev/null
+++ b/pymic/net_run_ssl/ssl_abstract.py
@@ -0,0 +1,101 @@
+# -*- coding: utf-8 -*-
+from __future__ import print_function, division
+import logging
+import numpy as np
+import random
+import torch
+import torchvision.transforms as transforms
+from pymic.io.nifty_dataset import NiftyDataset
+from pymic.loss.seg.util import get_soft_label
+from pymic.loss.seg.util import reshape_prediction_and_ground_truth
+from pymic.loss.seg.util import get_classwise_dice
+from pymic.loss.seg.ssl import EntropyLoss
+from pymic.net_run.agent_seg import SegmentationAgent
+from pymic.transform.trans_dict import TransformDict
+from pymic.util.ramps import sigmoid_rampup
+
+class SSLSegAgent(SegmentationAgent):
+    """
+    Abstract agent for semi-supervised segmentation.
+    The training() method is a placeholder that should be
+    implemented by each subclass.
+    """
+    def __init__(self, config, stage = 'train'):
+        super(SSLSegAgent, self).__init__(config, stage)
+        self.transform_dict  = TransformDict
+        self.train_set_unlab = None
+
+    def get_unlabeled_dataset_from_config(self):
+        root_dir  = self.config['dataset']['root_dir']
+        modal_num = self.config['dataset']['modal_num']
+        transform_names = self.config['dataset']['train_transform_unlab']
+
+        self.transform_list  = []
+        if(transform_names is None or len(transform_names) == 0):
+            data_transform = None
+        else:
+            transform_param = self.config['dataset']
+            transform_param['task'] = 'segmentation'
+            for name in transform_names:
+                if(name not in self.transform_dict):
+                    raise(ValueError("Undefined transform {0:}".format(name)))
+                one_transform = self.transform_dict[name](transform_param)
+                self.transform_list.append(one_transform)
+            data_transform = transforms.Compose(self.transform_list)
+
+        csv_file = self.config['dataset'].get('train_csv_unlab', None)
+        dataset  = NiftyDataset(root_dir=root_dir,
+                                csv_file  = csv_file,
+                                modal_num = modal_num,
+                                with_label= False,
+                                transform = data_transform )
+        return dataset
+
+    def create_dataset(self):
+        super(SSLSegAgent, self).create_dataset()
+        if(self.stage == 'train'):
+            if(self.train_set_unlab is None):
+                self.train_set_unlab = self.get_unlabeled_dataset_from_config()
+            if(self.deterministic):
+                def worker_init_fn(worker_id):
+                    random.seed(self.random_seed+worker_id)
+                worker_init = worker_init_fn
+            else:
+                worker_init = None
+
+            bn_train_unlab = self.config['dataset']['train_batch_size_unlab']
+            num_worker = self.config['dataset'].get('num_workder', 16)
+            self.train_loader_unlab = torch.utils.data.DataLoader(self.train_set_unlab,
+                batch_size = bn_train_unlab, shuffle=True, num_workers= num_worker,
+                worker_init_fn=worker_init)
+
+    def training(self):
+        pass
+
+    def write_scalars(self, train_scalars, valid_scalars, glob_it):
+        loss_scalar ={'train':train_scalars['loss'],
+            'valid':valid_scalars['loss']}
+        loss_sup_scalar   = {'train':train_scalars['loss_sup']}
+        loss_upsup_scalar = {'train':train_scalars['loss_reg']}
+        dice_scalar ={'train':train_scalars['avg_dice'], 'valid':valid_scalars['avg_dice']}
+        self.summ_writer.add_scalars('loss', loss_scalar, glob_it)
+        self.summ_writer.add_scalars('loss_sup', loss_sup_scalar, glob_it)
+        self.summ_writer.add_scalars('loss_reg', loss_upsup_scalar, glob_it)
+        self.summ_writer.add_scalars('regular_w', {'regular_w':train_scalars['regular_w']}, glob_it)
+        self.summ_writer.add_scalars('dice', dice_scalar, glob_it)
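+        # in addition to the average Dice, record a per-class Dice curve below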
class_num = self.config['network']['class_num'] + for c in range(class_num): + cls_dice_scalar = {'train':train_scalars['class_dice'][c], \ + 'valid':valid_scalars['class_dice'][c]} + self.summ_writer.add_scalars('class_{0:}_dice'.format(c), cls_dice_scalar, glob_it) + logging.info('train loss {0:.4f}, avg dice {1:.4f} '.format( + train_scalars['loss'], train_scalars['avg_dice']) + "[" + \ + ' '.join("{0:.4f}".format(x) for x in train_scalars['class_dice']) + "]") + logging.info('valid loss {0:.4f}, avg dice {1:.4f} '.format( + valid_scalars['loss'], valid_scalars['avg_dice']) + "[" + \ + ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") + + def train_valid(self): + self.trainIter_unlab = iter(self.train_loader_unlab) + super(SSLSegAgent, self).train_valid() diff --git a/pymic/net_run_ssl/ssl_cps.py b/pymic/net_run_ssl/ssl_cps.py index e9bc41b..49a6e11 100644 --- a/pymic/net_run_ssl/ssl_cps.py +++ b/pymic/net_run_ssl/ssl_cps.py @@ -9,10 +9,10 @@ from pymic.loss.seg.util import get_classwise_dice from pymic.util.ramps import sigmoid_rampup from pymic.net_run.get_optimizer import get_optimiser -from pymic.net_run_ssl.ssl_em import SSLEntropyMinimization +from pymic.net_run_ssl.ssl_abstract import SSLSegAgent from pymic.net.net_dict_seg import SegNetDict -class SSLCrossPseudoSupervision(SSLEntropyMinimization): +class SSLCrossPseudoSupervision(SSLSegAgent): """ Using cross pseudo supervision according to the following paper: Xiaokang Chen, Yuhui Yuan, Gang Zeng, Jingdong Wang, diff --git a/pymic/net_run_ssl/ssl_em.py b/pymic/net_run_ssl/ssl_em.py index 32f85c1..28a9dec 100644 --- a/pymic/net_run_ssl/ssl_em.py +++ b/pymic/net_run_ssl/ssl_em.py @@ -2,19 +2,16 @@ from __future__ import print_function, division import logging import numpy as np -import random import torch -import torchvision.transforms as transforms -from pymic.io.nifty_dataset import NiftyDataset from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.ssl import EntropyLoss -from pymic.net_run.agent_seg import SegmentationAgent +from pymic.net_run_ssl.ssl_abstract import SSLSegAgent from pymic.transform.trans_dict import TransformDict from pymic.util.ramps import sigmoid_rampup -class SSLEntropyMinimization(SegmentationAgent): +class SSLEntropyMinimization(SSLSegAgent): """ Implementation of the following paper: Yves Grandvalet and Yoshua Bengio, @@ -26,50 +23,6 @@ def __init__(self, config, stage = 'train'): self.transform_dict = TransformDict self.train_set_unlab = None - def get_unlabeled_dataset_from_config(self): - root_dir = self.config['dataset']['root_dir'] - modal_num = self.config['dataset']['modal_num'] - transform_names = self.config['dataset']['train_transform_unlab'] - - self.transform_list = [] - if(transform_names is None or len(transform_names) == 0): - data_transform = None - else: - transform_param = self.config['dataset'] - transform_param['task'] = 'segmentation' - for name in transform_names: - if(name not in self.transform_dict): - raise(ValueError("Undefined transform {0:}".format(name))) - one_transform = self.transform_dict[name](transform_param) - self.transform_list.append(one_transform) - data_transform = transforms.Compose(self.transform_list) - - csv_file = self.config['dataset'].get('train_csv_unlab', None) - dataset = NiftyDataset(root_dir=root_dir, - csv_file = csv_file, - modal_num = modal_num, - with_label= False, - transform = data_transform 
) - return dataset - - def create_dataset(self): - super(SSLEntropyMinimization, self).create_dataset() - if(self.stage == 'train'): - if(self.train_set_unlab is None): - self.train_set_unlab = self.get_unlabeled_dataset_from_config() - if(self.deterministic): - def worker_init_fn(worker_id): - random.seed(self.random_seed+worker_id) - worker_init = worker_init_fn - else: - worker_init = None - - bn_train_unlab = self.config['dataset']['train_batch_size_unlab'] - num_worker = self.config['dataset'].get('num_workder', 16) - self.train_loader_unlab = torch.utils.data.DataLoader(self.train_set_unlab, - batch_size = bn_train_unlab, shuffle=True, num_workers= num_worker, - worker_init_fn=worker_init) - def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] @@ -142,31 +95,4 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice} - return train_scalers - - def write_scalars(self, train_scalars, valid_scalars, glob_it): - loss_scalar ={'train':train_scalars['loss'], - 'valid':valid_scalars['loss']} - loss_sup_scalar = {'train':train_scalars['loss_sup']} - loss_upsup_scalar = {'train':train_scalars['loss_reg']} - dice_scalar ={'train':train_scalars['avg_dice'], 'valid':valid_scalars['avg_dice']} - self.summ_writer.add_scalars('loss', loss_scalar, glob_it) - self.summ_writer.add_scalars('loss_sup', loss_sup_scalar, glob_it) - self.summ_writer.add_scalars('loss_reg', loss_upsup_scalar, glob_it) - self.summ_writer.add_scalars('regular_w', {'regular_w':train_scalars['regular_w']}, glob_it) - self.summ_writer.add_scalars('dice', dice_scalar, glob_it) - class_num = self.config['network']['class_num'] - for c in range(class_num): - cls_dice_scalar = {'train':train_scalars['class_dice'][c], \ - 'valid':valid_scalars['class_dice'][c]} - self.summ_writer.add_scalars('class_{0:}_dice'.format(c), cls_dice_scalar, glob_it) - logging.info('train loss {0:.4f}, avg dice {1:.4f} '.format( - train_scalars['loss'], train_scalars['avg_dice']) + "[" + \ - ' '.join("{0:.4f}".format(x) for x in train_scalars['class_dice']) + "]") - logging.info('valid loss {0:.4f}, avg dice {1:.4f} '.format( - valid_scalars['loss'], valid_scalars['avg_dice']) + "[" + \ - ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") - - def train_valid(self): - self.trainIter_unlab = iter(self.train_loader_unlab) - super(SSLEntropyMinimization, self).train_valid() + return train_scalers \ No newline at end of file diff --git a/pymic/net_run_ssl/ssl_mt.py b/pymic/net_run_ssl/ssl_mt.py index d25edbc..b96e0b1 100644 --- a/pymic/net_run_ssl/ssl_mt.py +++ b/pymic/net_run_ssl/ssl_mt.py @@ -7,10 +7,10 @@ from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.util.ramps import sigmoid_rampup -from pymic.net_run_ssl.ssl_em import SSLEntropyMinimization +from pymic.net_run_ssl.ssl_abstract import SSLSegAgent from pymic.net.net_dict_seg import SegNetDict -class SSLMeanTeacher(SSLEntropyMinimization): +class SSLMeanTeacher(SSLSegAgent): """ Training and testing agent for semi-supervised segmentation """ diff --git a/pymic/net_run_ssl/ssl_urpc.py b/pymic/net_run_ssl/ssl_urpc.py index d2d5a1f..d9dd953 100644 --- a/pymic/net_run_ssl/ssl_urpc.py +++ b/pymic/net_run_ssl/ssl_urpc.py @@ -8,9 +8,9 @@ from pymic.loss.seg.util import 
reshape_prediction_and_ground_truth
 from pymic.loss.seg.util import get_classwise_dice
 from pymic.util.ramps import sigmoid_rampup
-from pymic.net_run_ssl.ssl_em import SSLEntropyMinimization
+from pymic.net_run_ssl.ssl_abstract import SSLSegAgent

-class SSLURPC(SSLEntropyMinimization):
+class SSLURPC(SSLSegAgent):
     """
     Uncertainty-Rectified Pyramid Consistency according to the following paper:
     Xiangde Luo, Wenjun Liao, Jieneng Chen, Tao Song, Yinan Chen,

From 58109748221f2ffb705b21c368665b181bcac050 Mon Sep 17 00:00:00 2001
From: taigw
Date: Mon, 1 Aug 2022 16:58:27 +0800
Subject: [PATCH 07/26] update wsl

rename the classes
---
 pymic/net_run_wsl/wsl_abstract.py     | 37 +++++++++++++++++++++++++++
 pymic/net_run_wsl/wsl_dmpls.py        | 10 +++-----
 pymic/net_run_wsl/wsl_em.py           | 34 +++----------------------
 pymic/net_run_wsl/wsl_gatedcrf.py     | 10 +++-----
 pymic/net_run_wsl/wsl_mumford_shah.py | 10 +++-----
 pymic/net_run_wsl/wsl_tv.py           | 10 +++-----
 pymic/net_run_wsl/wsl_ustm.py         | 12 +++------
 7 files changed, 57 insertions(+), 66 deletions(-)
 create mode 100644 pymic/net_run_wsl/wsl_abstract.py

diff --git a/pymic/net_run_wsl/wsl_abstract.py b/pymic/net_run_wsl/wsl_abstract.py
new file mode 100644
index 0000000..fe80ea5
--- /dev/null
+++ b/pymic/net_run_wsl/wsl_abstract.py
@@ -0,0 +1,37 @@
+# -*- coding: utf-8 -*-
+from __future__ import print_function, division
+import logging
+from pymic.net_run.agent_seg import SegmentationAgent
+
+class WSLSegAgent(SegmentationAgent):
+    """
+    Abstract training and testing agent for weakly-supervised segmentation
+    """
+    def __init__(self, config, stage = 'train'):
+        super(WSLSegAgent, self).__init__(config, stage)
+
+    def training(self):
+        pass
+
+    def write_scalars(self, train_scalars, valid_scalars, glob_it):
+        loss_scalar ={'train':train_scalars['loss'],
+            'valid':valid_scalars['loss']}
+        loss_sup_scalar   = {'train':train_scalars['loss_sup']}
+        loss_upsup_scalar = {'train':train_scalars['loss_reg']}
+        dice_scalar ={'train':train_scalars['avg_dice'], 'valid':valid_scalars['avg_dice']}
+        self.summ_writer.add_scalars('loss', loss_scalar, glob_it)
+        self.summ_writer.add_scalars('loss_sup', loss_sup_scalar, glob_it)
+        self.summ_writer.add_scalars('loss_reg', loss_upsup_scalar, glob_it)
+        self.summ_writer.add_scalars('regular_w', {'regular_w':train_scalars['regular_w']}, glob_it)
+        self.summ_writer.add_scalars('dice', dice_scalar, glob_it)
+        class_num = self.config['network']['class_num']
+        for c in range(class_num):
+            cls_dice_scalar = {'train':train_scalars['class_dice'][c], \
+                'valid':valid_scalars['class_dice'][c]}
+            self.summ_writer.add_scalars('class_{0:}_dice'.format(c), cls_dice_scalar, glob_it)
+        logging.info('train loss {0:.4f}, avg dice {1:.4f} '.format(
+            train_scalars['loss'], train_scalars['avg_dice']) + "[" + \
+            ' '.join("{0:.4f}".format(x) for x in train_scalars['class_dice']) + "]")
+        logging.info('valid loss {0:.4f}, avg dice {1:.4f} '.format(
+            valid_scalars['loss'], valid_scalars['avg_dice']) + "[" + \
+            ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]")
diff --git a/pymic/net_run_wsl/wsl_dmpls.py b/pymic/net_run_wsl/wsl_dmpls.py
index c42ed7a..234f1a2 100644
--- a/pymic/net_run_wsl/wsl_dmpls.py
+++ b/pymic/net_run_wsl/wsl_dmpls.py
@@ -4,18 +4,14 @@
 import numpy as np
 import random
 import torch
-import torchvision.transforms as transforms
-from pymic.io.nifty_dataset import NiftyDataset
 from pymic.loss.seg.util import get_soft_label
 from pymic.loss.seg.util import reshape_prediction_and_ground_truth
 from pymic.loss.seg.util import
get_classwise_dice from pymic.loss.seg.dice import DiceLoss -from pymic.loss.seg.ssl import TotalVariationLoss -from pymic.net_run.agent_seg import SegmentationAgent -from pymic.net_run_wsl.wsl_em import WSL_EntropyMinimization +from pymic.net_run_wsl.wsl_abstract import WSLSegAgent from pymic.util.ramps import sigmoid_rampup -class WSL_DMPLS(WSL_EntropyMinimization): +class WSLDMPLS(WSLSegAgent): """ Implementation of the following paper: Xiangde Luo, Minhao Hu, Wenjun Liao, Shuwei Zhai, Tao Song, Guotai Wang, @@ -28,7 +24,7 @@ def __init__(self, config, stage = 'train'): if net_type not in ['DualBranchUNet2D', 'DualBranchUNet3D']: raise ValueError("""For WSL_DMPLS, a dual branch network is expected. \ It only supports DualBranchUNet2D and DualBranchUNet3D currently.""") - super(WSL_DMPLS, self).__init__(config, stage) + super(WSLDMPLS, self).__init__(config, stage) def training(self): class_num = self.config['network']['class_num'] diff --git a/pymic/net_run_wsl/wsl_em.py b/pymic/net_run_wsl/wsl_em.py index 9534504..cc19600 100644 --- a/pymic/net_run_wsl/wsl_em.py +++ b/pymic/net_run_wsl/wsl_em.py @@ -2,24 +2,21 @@ from __future__ import print_function, division import logging import numpy as np -import random import torch -import torchvision.transforms as transforms -from pymic.io.nifty_dataset import NiftyDataset from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.ssl import EntropyLoss from pymic.net_run.agent_seg import SegmentationAgent -from pymic.transform.trans_dict import TransformDict +from pymic.net_run_wsl.wsl_abstract import WSLSegAgent from pymic.util.ramps import sigmoid_rampup -class WSL_EntropyMinimization(SegmentationAgent): +class WSLEntropyMinimization(WSLSegAgent): """ Training and testing agent for semi-supervised segmentation """ def __init__(self, config, stage = 'train'): - super(WSL_EntropyMinimization, self).__init__(config, stage) + super(WSLEntropyMinimization, self).__init__(config, stage) def training(self): class_num = self.config['network']['class_num'] @@ -85,27 +82,4 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice} - return train_scalers - - def write_scalars(self, train_scalars, valid_scalars, glob_it): - loss_scalar ={'train':train_scalars['loss'], - 'valid':valid_scalars['loss']} - loss_sup_scalar = {'train':train_scalars['loss_sup']} - loss_upsup_scalar = {'train':train_scalars['loss_reg']} - dice_scalar ={'train':train_scalars['avg_dice'], 'valid':valid_scalars['avg_dice']} - self.summ_writer.add_scalars('loss', loss_scalar, glob_it) - self.summ_writer.add_scalars('loss_sup', loss_sup_scalar, glob_it) - self.summ_writer.add_scalars('loss_reg', loss_upsup_scalar, glob_it) - self.summ_writer.add_scalars('regular_w', {'regular_w':train_scalars['regular_w']}, glob_it) - self.summ_writer.add_scalars('dice', dice_scalar, glob_it) - class_num = self.config['network']['class_num'] - for c in range(class_num): - cls_dice_scalar = {'train':train_scalars['class_dice'][c], \ - 'valid':valid_scalars['class_dice'][c]} - self.summ_writer.add_scalars('class_{0:}_dice'.format(c), cls_dice_scalar, glob_it) - logging.info('train loss {0:.4f}, avg dice {1:.4f} '.format( - train_scalars['loss'], train_scalars['avg_dice']) + "[" + \ - ' '.join("{0:.4f}".format(x) for x in 
train_scalars['class_dice']) + "]") - logging.info('valid loss {0:.4f}, avg dice {1:.4f} '.format( - valid_scalars['loss'], valid_scalars['avg_dice']) + "[" + \ - ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") + return train_scalers \ No newline at end of file diff --git a/pymic/net_run_wsl/wsl_gatedcrf.py b/pymic/net_run_wsl/wsl_gatedcrf.py index 8e9c6de..af6c562 100644 --- a/pymic/net_run_wsl/wsl_gatedcrf.py +++ b/pymic/net_run_wsl/wsl_gatedcrf.py @@ -2,24 +2,20 @@ from __future__ import print_function, division import logging import numpy as np -import random import torch -import torchvision.transforms as transforms -from pymic.io.nifty_dataset import NiftyDataset from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.gatedcrf import ModelLossSemsegGatedCRF -from pymic.net_run.agent_seg import SegmentationAgent -from pymic.net_run_wsl.wsl_em import WSL_EntropyMinimization +from pymic.net_run_wsl.wsl_abstract import WSLSegAgent from pymic.util.ramps import sigmoid_rampup -class WSL_GatedCRF(WSL_EntropyMinimization): +class WSLGatedCRF(WSLSegAgent): """ Training and testing agent for semi-supervised segmentation """ def __init__(self, config, stage = 'train'): - super(WSL_GatedCRF, self).__init__(config, stage) + super(WSLGatedCRF, self).__init__(config, stage) # parameters for gated CRF wsl_cfg = self.config['weakly_supervised_learning'] w0 = wsl_cfg.get('GatedCRFLoss_W0'.lower(), 1.0) diff --git a/pymic/net_run_wsl/wsl_mumford_shah.py b/pymic/net_run_wsl/wsl_mumford_shah.py index 909a65b..f642e59 100644 --- a/pymic/net_run_wsl/wsl_mumford_shah.py +++ b/pymic/net_run_wsl/wsl_mumford_shah.py @@ -2,24 +2,20 @@ from __future__ import print_function, division import logging import numpy as np -import random import torch -import torchvision.transforms as transforms -from pymic.io.nifty_dataset import NiftyDataset from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.mumford_shah import MumfordShahLoss -from pymic.net_run.agent_seg import SegmentationAgent -from pymic.net_run_wsl.wsl_em import WSL_EntropyMinimization +from pymic.net_run_wsl.wsl_abstract import WSLSegAgent from pymic.util.ramps import sigmoid_rampup -class WSL_MumfordShah(WSL_EntropyMinimization): +class WSLMumfordShah(WSLSegAgent): """ Training and testing agent for semi-supervised segmentation """ def __init__(self, config, stage = 'train'): - super(WSL_MumfordShah, self).__init__(config, stage) + super(WSLMumfordShah, self).__init__(config, stage) def training(self): class_num = self.config['network']['class_num'] diff --git a/pymic/net_run_wsl/wsl_tv.py b/pymic/net_run_wsl/wsl_tv.py index f11c5e0..9492150 100644 --- a/pymic/net_run_wsl/wsl_tv.py +++ b/pymic/net_run_wsl/wsl_tv.py @@ -2,24 +2,20 @@ from __future__ import print_function, division import logging import numpy as np -import random import torch -import torchvision.transforms as transforms -from pymic.io.nifty_dataset import NiftyDataset from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.ssl import TotalVariationLoss -from pymic.net_run.agent_seg import SegmentationAgent -from pymic.net_run_wsl.wsl_em import WSL_EntropyMinimization +from 
pymic.net_run_wsl.wsl_abstract import WSLSegAgent
 from pymic.util.ramps import sigmoid_rampup

-class WSL_TotalVariation(WSL_EntropyMinimization):
+class WSLTotalVariation(WSLSegAgent):
     """
     Training and testing agent for semi-supervised segmentation
     """
     def __init__(self, config, stage = 'train'):
-        super(WSL_TotalVariation, self).__init__(config, stage)
+        super(WSLTotalVariation, self).__init__(config, stage)

     def training(self):
         class_num = self.config['network']['class_num']
diff --git a/pymic/net_run_wsl/wsl_ustm.py b/pymic/net_run_wsl/wsl_ustm.py
index c7306e8..32e79df 100644
--- a/pymic/net_run_wsl/wsl_ustm.py
+++ b/pymic/net_run_wsl/wsl_ustm.py
@@ -5,27 +5,23 @@
 import random
 import torch
 import torch.nn.functional as F
-import torchvision.transforms as transforms
-from pymic.io.nifty_dataset import NiftyDataset
 from pymic.loss.seg.util import get_soft_label
 from pymic.loss.seg.util import reshape_prediction_and_ground_truth
 from pymic.loss.seg.util import get_classwise_dice
-from pymic.loss.seg.ssl import EntropyLoss
 from pymic.net.net_dict_seg import SegNetDict
-from pymic.net_run_wsl.wsl_em import WSL_EntropyMinimization
-from pymic.transform.trans_dict import TransformDict
+from pymic.net_run_wsl.wsl_abstract import WSLSegAgent
 from pymic.util.ramps import sigmoid_rampup

-class WSL_USTM(WSL_EntropyMinimization):
+class WSLUSTM(WSLSegAgent):
     """
     Training and testing agent for semi-supervised segmentation
     """
     def __init__(self, config, stage = 'train'):
-        super(WSL_USTM, self).__init__(config, stage)
+        super(WSLUSTM, self).__init__(config, stage)
         self.net_ema = None

     def create_network(self):
-        super(WSL_USTM, self).create_network()
+        super(WSLUSTM, self).create_network()
         if(self.net_ema is None):
             net_name = self.config['network']['net_type']
             if(net_name not in SegNetDict):

From 5f99937ec783526dc29e04224cb5129099dc7654 Mon Sep 17 00:00:00 2001
From: taigw
Date: Thu, 4 Aug 2022 21:39:09 +0800
Subject: [PATCH 08/26] update nll method

rename nll method
---
 pymic/net_run_nll/{cl.py => nll_cl.py}     |  6 +++---
 .../{co_teaching.py => nll_co_teaching.py} | 18 ++++++++----------
 2 files changed, 11 insertions(+), 13 deletions(-)
 rename pymic/net_run_nll/{cl.py => nll_cl.py} (97%)
 rename pymic/net_run_nll/{co_teaching.py => nll_co_teaching.py} (94%)

diff --git a/pymic/net_run_nll/cl.py b/pymic/net_run_nll/nll_cl.py
similarity index 97%
rename from pymic/net_run_nll/cl.py
rename to pymic/net_run_nll/nll_cl.py
index de31e3e..8792ccd 100644
--- a/pymic/net_run_nll/cl.py
+++ b/pymic/net_run_nll/nll_cl.py
@@ -45,9 +45,9 @@ def get_confident_map(gt, pred, CL_type = 'both'):
         noise = cleanlab.pruning.get_noise_indices(gt, prob, prune_method=CL_type, n_jobs=1)
     return noise

-class SegmentationAgentwithCL(SegmentationAgent):
+class NLLConfidentLearn(SegmentationAgent):
     def __init__(self, config, stage = 'test'):
-        super(SegmentationAgentwithCL, self).__init__(config, stage)
+        super(NLLConfidentLearn, self).__init__(config, stage)

     def infer_with_cl(self):
         device_ids = self.config['testing']['gpus']
@@ -179,7 +179,7 @@ def main():
                         with_label= True,
                         transform = data_transform )

-    agent = SegmentationAgentwithCL(config, 'test')
+    agent = NLLConfidentLearn(config, 'test')
     agent.set_datasets(None, None, dataset)
     agent.transform_list = transform_list
     agent.run()
diff --git a/pymic/net_run_nll/co_teaching.py b/pymic/net_run_nll/nll_co_teaching.py
similarity index 94%
rename from pymic/net_run_nll/co_teaching.py
rename to pymic/net_run_nll/nll_co_teaching.py
index 228e3bd..e1392eb 100644
--- a/pymic/net_run_nll/co_teaching.py
+++ b/pymic/net_run_nll/nll_co_teaching.py
@@ -29,16 +29,14 @@
 import sys
 from pymic.util.parse_config import *

-class CoTeachingAgent(SegmentationAgent):
+class NLLCoTeaching(SegmentationAgent):
     """
-    Using cross pseudo supervision according to the following paper:
-    Xiaokang Chen, Yuhui Yuan, Gang Zeng, Jingdong Wang,
-    Semi-Supervised Semantic Segmentation with Cross Pseudo Supervision,
-    CVPR 2021, pp. 2613-2022.
-    https://arxiv.org/abs/2106.01226
+    Co-teaching: Robust Training of Deep Neural Networks with Extremely
+    Noisy Labels
+    https://arxiv.org/abs/1804.06872
     """
     def __init__(self, config, stage = 'train'):
-        super(CoTeachingAgent, self).__init__(config, stage)
+        super(NLLCoTeaching, self).__init__(config, stage)
         self.net2 = None
         self.optimizer2 = None
         self.scheduler2 = None
@@ -48,7 +46,7 @@ def __init__(self, config, stage = 'train'):
                 " coteaching, the specified loss {0:} is ignored".format(loss_type))

     def create_network(self):
-        super(CoTeachingAgent, self).create_network()
+        super(NLLCoTeaching, self).create_network()
         if(self.net2 is None):
             net_name = self.config['network']['net_type']
             if(net_name not in SegNetDict):
@@ -74,7 +72,7 @@ def train_valid(self):
                 self.config['training']['lr_milestones'],
                 self.config['training']['lr_gamma'],
                 last_epoch = last_iter)
-        super(CoTeachingAgent, self).train_valid()
+        super(NLLCoTeaching, self).train_valid()

     def training(self):
         class_num = self.config['network']['class_num']
@@ -211,5 +209,5 @@ def write_scalars(self, train_scalars, valid_scalars, glob_it):
                         format='%(message)s')
     logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
     logging_config(config)
-    agent = CoTeachingAgent(config, stage)
+    agent = NLLCoTeaching(config, stage)
     agent.run()
\ No newline at end of file

From ee67842972f80841edcc5641fcfd9be0261e9ec9 Mon Sep 17 00:00:00 2001
From: taigw
Date: Sat, 6 Aug 2022 12:48:53 +0800
Subject: [PATCH 09/26] update test-time augmentation and post process

---
 pymic/net_run/infer_func.py |  2 +-
 pymic/util/image_process.py | 12 ++++++++----
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/pymic/net_run/infer_func.py b/pymic/net_run/infer_func.py
index 35bfb4c..78184fe 100644
--- a/pymic/net_run/infer_func.py
+++ b/pymic/net_run/infer_func.py
@@ -140,7 +140,7 @@ def run(self, model, image):
             outputs4 = self.__infer(torch.flip(image, [-2, -1]))
             if(isinstance(outputs1, (tuple, list))):
                 outputs = []
-                for i in range(len(outputs)):
+                for i in range(len(outputs1)):
                     temp_out1 = outputs1[i]
                     temp_out2 = torch.flip(outputs2[i], [-2])
                     temp_out3 = torch.flip(outputs3[i], [-1])
diff --git a/pymic/util/image_process.py b/pymic/util/image_process.py
index 61b577b..aa9f611 100644
--- a/pymic/util/image_process.py
+++ b/pymic/util/image_process.py
@@ -107,9 +107,9 @@ def crop_and_pad_ND_array_to_desired_shape(image, out_shape, pad_mod):

     return image_pad

-def get_largest_component(image):
+def get_largest_k_components(image, k = 1):
     """
-    get the largest component from 2D or 3D binary image
+    get the largest K components from 2D or 3D binary image
     image: nd array
     """
     dim = len(image.shape)
@@ -124,8 +124,12 @@ def get_largest_component(image):
         raise ValueError("the dimension number should be 2 or 3")
     labeled_array, numpatches = ndimage.label(image, s)
     sizes = ndimage.sum(image, labeled_array, range(1, numpatches + 1))
-    max_label = np.where(sizes == sizes.max())[0] + 1
-    output = np.asarray(labeled_array == max_label, np.uint8)
+    sizes_sort = sorted(sizes, reverse = True)
+    kmin = min(k, numpatches)
+    output = np.zeros_like(image)
+    for i in range(kmin):
+        labeli = np.where(sizes == sizes_sort[i])[0] + 1
+        output = output + np.asarray(labeled_array == labeli, np.uint8)
     return output

 def get_euclidean_distance(image, dim = 3, spacing = [1.0, 1.0, 1.0]):
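For illustration (not part of the patch series): a minimal usage sketch of the updated get_largest_k_components, assuming a toy 2D binary mask with two foreground blobs. The expected sums follow from the component sizes in the comments.

    import numpy as np
    from pymic.util.image_process import get_largest_k_components

    mask = np.zeros((8, 8), np.uint8)
    mask[1:3, 1:3] = 1   # small component, 4 pixels
    mask[4:8, 4:8] = 1   # large component, 16 pixels
    top1 = get_largest_k_components(mask, k = 1)   # keeps only the large blob
    top2 = get_largest_k_components(mask, k = 2)   # keeps both blobs
    print(top1.sum(), top2.sum())                  # 16 20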
From 7235afd08f7e2e03bd164ad2a16e39f8c60c4696 Mon Sep 17 00:00:00 2001
From: taigw
Date: Sat, 6 Aug 2022 14:04:36 +0800
Subject: [PATCH 10/26] add post process for inference

---
 pymic/net_run/agent_seg.py  | 14 ++++++++++++-
 pymic/util/image_process.py |  7 ++-----
 pymic/util/post_process.py  | 41 +++++++++++++++++++++++++++++++++
 3 files changed, 56 insertions(+), 6 deletions(-)
 create mode 100644 pymic/util/post_process.py

diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py
index 0cc59b3..38c75ae 100644
--- a/pymic/net_run/agent_seg.py
+++ b/pymic/net_run/agent_seg.py
@@ -25,13 +25,16 @@ from pymic.loss.seg.util import reshape_prediction_and_ground_truth
 from pymic.loss.seg.util import get_classwise_dice
 from pymic.transform.trans_dict import TransformDict
+from pymic.util.post_process import PostProcessDict
 from pymic.util.image_process import convert_label
 from pymic.util.general import keyword_match

 class SegmentationAgent(NetRunAgent):
     def __init__(self, config, stage = 'train'):
         super(SegmentationAgent, self).__init__(config, stage)
-        self.transform_dict = TransformDict
+        self.transform_dict   = TransformDict
+        self.postprocess_dict = PostProcessDict
+        self.postprocessor    = None

     def get_stage_dataset_from_config(self, stage):
         assert(stage in ['train', 'valid', 'test'])
@@ -155,6 +158,9 @@ def get_loss_value(self, data, pred, gt, param = None):
         loss_value = self.loss_calculator(loss_input_dict)
         return loss_value

+    def set_postprocessor(self, postprocessor):
+        self.postprocessor = postprocessor
+
     def training(self):
         class_num = self.config['network']['class_num']
         iter_valid = self.config['training']['iter_valid']
@@ -410,6 +416,9 @@ def test_time_dropout(m):
         infer_cfg = self.config['testing']
         infer_cfg['class_num'] = self.config['network']['class_num']
         self.inferer = Inferer(infer_cfg)
+        postpro_name = self.config['testing'].get('post_process', None)
+        if(self.postprocessor is None and postpro_name is not None):
+            self.postprocessor = PostProcessDict[postpro_name](self.config['testing'])
         infer_time_list = []
         with torch.no_grad():
             for data in self.test_loader:
@@ -518,6 +527,9 @@ def save_ouputs(self, data):
             output = np.asarray(np.argmax(prob, axis = 1), np.uint8)
         if((label_source is not None) and (label_target is not None)):
             output = convert_label(output, label_source, label_target)
+        if(self.postprocessor is not None):
+            for i in range(len(names)):
+                output[i] = self.postprocessor(output[i])
         # save the output and (optionally) probability predictions
         root_dir = self.config['dataset']['root_dir']
         for i in range(len(names)):
diff --git a/pymic/util/image_process.py b/pymic/util/image_process.py
index aa9f611..896e8c1 100644
--- a/pymic/util/image_process.py
+++ b/pymic/util/image_process.py
@@ -116,12 +116,9 @@ def get_largest_k_components(image, k = 1):
     if(image.sum() == 0 ):
         print('the largest component is null')
         return image
-    if(dim == 2):
-        s = ndimage.generate_binary_structure(2,1)
-    elif(dim == 3):
-        s = ndimage.generate_binary_structure(3,1)
-    else:
+    if(dim < 2 or dim > 3):
         raise ValueError("the dimension number should be 2 or 3")
+    s = ndimage.generate_binary_structure(dim,1)
     labeled_array, numpatches = ndimage.label(image, s)
     sizes = ndimage.sum(image, labeled_array, range(1, numpatches + 1))
     sizes_sort = sorted(sizes, reverse = True)
diff --git a/pymic/util/post_process.py b/pymic/util/post_process.py
new file mode 100644
index 0000000..da133ca
--- /dev/null
+++ b/pymic/util/post_process.py
@@ -0,0 +1,41 @@
+# -*- coding: utf-8 -*-
+from __future__ import absolute_import, print_function
+
+import os
+import numpy as np
+import SimpleITK as sitk
+from pymic.util.image_process import get_largest_k_components
+
+class PostProcess(object):
+    def __init__(self, params):
+        self.params = params
+
+    def __call__(self, seg):
+        return seg
+
+class PostKeepLargestComponent(PostProcess):
+    def __init__(self, params):
+        super(PostKeepLargestComponent, self).__init__(params)
+        self.mode = params.get("KeepLargestComponent_mode".lower(), 1)
+        """
+        mode = 1: keep the largest component of the union of foreground classes.
+        mode = 2: keep the largest component for each foreground class.
+        """
+
+    def __call__(self, seg):
+        if(self.mode == 1):
+            mask = np.asarray(seg > 0, np.uint8)
+            mask = get_largest_k_components(mask)
+            seg  = seg * mask
+        elif(self.mode == 2):
+            class_num = seg.max()
+            output = np.zeros_like(seg)
+            for c in range(1, class_num + 1):
+                seg_c  = np.asarray(seg == c, np.uint8)
+                seg_c  = get_largest_k_components(seg_c)
+                output = output + seg_c * c
+            seg = output
+        return seg
+
+PostProcessDict = {
+    'KeepLargestComponent': PostKeepLargestComponent}
\ No newline at end of file
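For illustration (not part of the patch series): the new hook can be attached manually, or built by the agent itself when post_process = KeepLargestComponent is set in the [testing] section of the configuration. A hypothetical sketch, assuming an existing SegmentationAgent instance named agent:

    from pymic.util.post_process import PostKeepLargestComponent

    # keys are lower-cased, matching params.get("KeepLargestComponent_mode".lower(), 1)
    params = {"keeplargestcomponent_mode": 2}   # one component per foreground class
    agent.set_postprocessor(PostKeepLargestComponent(params))  # applied in save_ouputs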
From 03d99646aa665423e8f3c0e260a22b7423b758f8 Mon Sep 17 00:00:00 2001
From: taigw
Date: Mon, 8 Aug 2022 16:49:14 +0800
Subject: [PATCH 11/26] update ssl_em according to agent_seg, allow recording
 lr during training, and support ReduceLROnPlateau

---
 pymic/net_run_ssl/ssl_abstract.py | 7 ++++---
 pymic/net_run_ssl/ssl_cps.py      | 4 ++--
 pymic/net_run_ssl/ssl_em.py       | 4 +++-
 pymic/util/evaluation_seg.py      | 6 +++---
 4 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/pymic/net_run_ssl/ssl_abstract.py b/pymic/net_run_ssl/ssl_abstract.py
index acb3b5a..f1b97ba 100644
--- a/pymic/net_run_ssl/ssl_abstract.py
+++ b/pymic/net_run_ssl/ssl_abstract.py
@@ -28,7 +28,7 @@ def __init__(self, config, stage = 'train'):

     def get_unlabeled_dataset_from_config(self):
         root_dir  = self.config['dataset']['root_dir']
-        modal_num = self.config['dataset']['modal_num']
+        modal_num = self.config['dataset'].get('modal_num', 1)
         transform_names = self.config['dataset']['train_transform_unlab']

         self.transform_list = []
@@ -72,8 +72,8 @@ def worker_init_fn(worker_id):

     def training(self):
         pass
-
-    def write_scalars(self, train_scalars, valid_scalars, glob_it):
+
+    def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it):
         loss_scalar ={'train':train_scalars['loss'],
                       'valid':valid_scalars['loss']}
         loss_sup_scalar  = {'train':train_scalars['loss_sup']}
@@ -83,6 +83,7 @@ def write_scalars(self, train_scalars, valid_scalars, glob_it):
         self.summ_writer.add_scalars('loss_sup', loss_sup_scalar, glob_it)
         self.summ_writer.add_scalars('loss_reg', loss_upsup_scalar, glob_it)
         self.summ_writer.add_scalars('regular_w', {'regular_w':train_scalars['regular_w']}, glob_it)
+        self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it)
         self.summ_writer.add_scalars('dice', dice_scalar, glob_it)
         class_num = self.config['network']['class_num']
         for c in range(class_num):
diff --git a/pymic/net_run_ssl/ssl_cps.py b/pymic/net_run_ssl/ssl_cps.py
index 49a6e11..a878ea0 100644
--- a/pymic/net_run_ssl/ssl_cps.py
+++ b/pymic/net_run_ssl/ssl_cps.py
@@ -8,7 +8,7 @@ from pymic.loss.seg.util import reshape_prediction_and_ground_truth
 from pymic.loss.seg.util import get_classwise_dice
 from pymic.util.ramps import sigmoid_rampup
-from pymic.net_run.get_optimizer import get_optimiser
+from pymic.net_run.get_optimizer import get_optimizer
 from pymic.net_run_ssl.ssl_abstract import SSLSegAgent
 from pymic.net.net_dict_seg import SegNetDict

@@ -41,7 +41,7 @@ def create_network(self):
     def train_valid(self):
         # create optimizer for the second network
         if(self.optimizer2 is None):
-            self.optimizer2 = get_optimiser(self.config['training']['optimizer'],
+            self.optimizer2 = get_optimizer(self.config['training']['optimizer'],
                 self.net2.parameters(),
                 self.config['training'])
         last_iter = -1
diff --git a/pymic/net_run_ssl/ssl_em.py b/pymic/net_run_ssl/ssl_em.py
index 28a9dec..bb1cd55 100644
--- a/pymic/net_run_ssl/ssl_em.py
+++ b/pymic/net_run_ssl/ssl_em.py
@@ -10,6 +10,7 @@ from pymic.net_run_ssl.ssl_abstract import SSLSegAgent
 from pymic.transform.trans_dict import TransformDict
 from pymic.util.ramps import sigmoid_rampup
+from pymic.util.general import keyword_match

 class SSLEntropyMinimization(SSLSegAgent):
     """
@@ -73,7 +74,8 @@ def training(self):
             # if (self.config['training']['use'])
             loss.backward()
             self.optimizer.step()
-            self.scheduler.step()
+            if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")):
+                self.scheduler.step()

             train_loss = train_loss + loss.item()
             train_loss_sup = train_loss_sup + loss_sup.item()
diff --git a/pymic/util/evaluation_seg.py b/pymic/util/evaluation_seg.py
index 61ae51c..b04880a 100644
--- a/pymic/util/evaluation_seg.py
+++ b/pymic/util/evaluation_seg.py
@@ -74,7 +74,7 @@ def get_edge_points(img):

     return edge

-def binary_hausdorff95(s, g, spacing = None):
+def binary_hd95(s, g, spacing = None):
     """
     get the hausdorff distance between a binary segmentation and the ground truth
     inputs:
@@ -165,8 +165,8 @@ def get_binary_evaluation_score(s_volume, g_volume, spacing, metric):
     elif(metric_lower == 'assd'):
         score = binary_assd(s_volume, g_volume, spacing)

-    elif(metric_lower == "hausdorff95"):
-        score = binary_hausdorff95(s_volume, g_volume, spacing)
+    elif(metric_lower == "hd95"):
+        score = binary_hd95(s_volume, g_volume, spacing)

     elif(metric_lower == "rve"):
         score = binary_relative_volume_error(s_volume, g_volume)
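For illustration (not part of the patch series): the reason for the keyword_match guard, in isolation. Iteration-driven schedulers such as MultiStepLR are stepped after every optimizer update, while ReduceLROnPlateau must wait for the validation metric, so it is stepped once per validation round instead. A self-contained sketch in plain PyTorch; PyMIC wires this through get_lr_scheduler and the lr_scheduler key of the [training] section:

    import torch

    net = torch.nn.Linear(4, 2)
    opt = torch.optim.SGD(net.parameters(), lr = 0.01)
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode = 'max',
                                                       factor = 0.5, patience = 10)
    for it in range(100):    # training iterations: no sched.step() here
        opt.step()
    sched.step(0.82)         # once per validation, with e.g. the average Dice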
From e87842972f80841edcc5641fcfd9be0261e9ec9 Mon Sep 17 00:00:00 2001
From: taigw
Date: Tue, 9 Aug 2022 11:54:57 +0800
Subject: [PATCH 12/26] update ssl method

add lr scheduler
---
 pymic/net_run_ssl/ssl_cps.py  | 36 ++++++++++++++++++++--------------
 pymic/net_run_ssl/ssl_main.py |  4 ++--
 pymic/net_run_ssl/ssl_mt.py   |  6 ++++--
 pymic/net_run_ssl/ssl_uamt.py |  7 +++++--
 pymic/net_run_ssl/ssl_urpc.py |  7 ++++---
 5 files changed, 36 insertions(+), 24 deletions(-)

diff --git a/pymic/net_run_ssl/ssl_cps.py b/pymic/net_run_ssl/ssl_cps.py
index a878ea0..582cdc0 100644
--- a/pymic/net_run_ssl/ssl_cps.py
+++ b/pymic/net_run_ssl/ssl_cps.py
@@ -7,12 +7,13 @@ from pymic.loss.seg.util import get_soft_label
 from pymic.loss.seg.util import reshape_prediction_and_ground_truth
 from pymic.loss.seg.util import get_classwise_dice
-from pymic.util.ramps import sigmoid_rampup
-from pymic.net_run.get_optimizer import get_optimizer
+from pymic.net_run.get_optimizer import get_optimizer, get_lr_scheduler
 from pymic.net_run_ssl.ssl_abstract import SSLSegAgent
 from pymic.net.net_dict_seg import SegNetDict
+from pymic.util.ramps import sigmoid_rampup
+from pymic.util.general import keyword_match

-class SSLCrossPseudoSupervision(SSLSegAgent):
+class SSLCPS(SSLSegAgent):
     """
     Using cross pseudo supervision according to the following paper:
     Xiaokang Chen, Yuhui Yuan, Gang Zeng, Jingdong Wang,
@@ -21,13 +22,13 @@ class SSLCrossPseudoSupervision(SSLSegAgent):
     https://arxiv.org/abs/2106.01226
     """
     def __init__(self, config, stage = 'train'):
-        super(SSLCrossPseudoSupervision, self).__init__(config, stage)
+        super(SSLCPS, self).__init__(config, stage)
         self.net2 = None
         self.optimizer2 = None
         self.scheduler2 = None

     def create_network(self):
-        super(SSLCrossPseudoSupervision, self).create_network()
+        super(SSLCPS, self).create_network()
         if(self.net2 is None):
             net_name = self.config['network']['net_type']
             if(net_name not in SegNetDict):
@@ -40,20 +41,18 @@ def create_network(self):

     def train_valid(self):
         # create optimizer for the second network
+        opt_params = self.config['training']
         if(self.optimizer2 is None):
-            self.optimizer2 = get_optimizer(self.config['training']['optimizer'],
-                self.net2.parameters(),
-                self.config['training'])
+            self.optimizer2 = get_optimizer(opt_params['optimizer'],
+                self.net2.parameters(), opt_params)
         last_iter = -1
         # if(self.checkpoint is not None):
         #     self.optimizer2.load_state_dict(self.checkpoint['optimizer_state_dict'])
         #     last_iter = self.checkpoint['iteration'] - 1
         if(self.scheduler2 is None):
-            self.scheduler2 = optim.lr_scheduler.MultiStepLR(self.optimizer2,
-                self.config['training']['lr_milestones'],
-                self.config['training']['lr_gamma'],
-                last_epoch = last_iter)
-        super(SSLCrossPseudoSupervision, self).train_valid()
+            opt_params["last_iter"] = last_iter
+            self.scheduler2 = get_lr_scheduler(self.optimizer, opt_params)
+        super(SSLCPS, self).train_valid()

     def training(self):
         class_num = self.config['network']['class_num']
@@ -121,9 +120,10 @@ def training(self):

             loss.backward()
             self.optimizer.step()
-            self.scheduler.step()
             self.optimizer2.step()
-            self.scheduler2.step()
+            if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")):
+                self.scheduler.step()
+                self.scheduler2.step()

             train_loss = train_loss + loss.item()
             train_loss_sup1 = train_loss_sup1 + loss_sup1.item()
@@ -152,6 +152,12 @@ def training(self):
             'loss_pse_sup1':train_avg_loss_pse_sup1, 'loss_pse_sup2': train_avg_loss_pse_sup2,
             'regular_w':regular_w, 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice}
         return train_scalers
+
+    def validation(self):
+        return_value = super(SSLCPS, self).validation()
+        if(keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")):
+            self.scheduler2.step(return_value['avg_dice'])
+        return return_value

     def write_scalars(self, train_scalars, valid_scalars, glob_it):
         loss_scalar ={'train':train_scalars['loss'],
                       'valid':valid_scalars['loss']}
         loss_sup_scalar  = {'net1':train_scalars['loss_sup1'],
diff --git a/pymic/net_run_ssl/ssl_main.py b/pymic/net_run_ssl/ssl_main.py
index cf5a8cd..54492ae 100644
--- a/pymic/net_run_ssl/ssl_main.py
+++ b/pymic/net_run_ssl/ssl_main.py
@@ -9,13 +9,13 @@ from pymic.net_run_ssl.ssl_mt import SSLMeanTeacher
 from pymic.net_run_ssl.ssl_uamt import SSLUncertaintyAwareMeanTeacher
 from pymic.net_run_ssl.ssl_urpc import SSLURPC
-from pymic.net_run_ssl.ssl_cps import SSLCrossPseudoSupervision
+from pymic.net_run_ssl.ssl_cps import SSLCPS

 SSLMethodDict = {'EntropyMinimization': SSLEntropyMinimization,
     'MeanTeacher': SSLMeanTeacher,
     'UAMT': SSLUncertaintyAwareMeanTeacher,
     'URPC': SSLURPC,
-    'CPS': SSLCrossPseudoSupervision}
+    'CPS': SSLCPS}

 def main():
     if(len(sys.argv) < 3):
diff --git a/pymic/net_run_ssl/ssl_mt.py b/pymic/net_run_ssl/ssl_mt.py
index b96e0b1..7905968 100644
--- a/pymic/net_run_ssl/ssl_mt.py
+++ b/pymic/net_run_ssl/ssl_mt.py
@@ -6,9 +6,10 @@ from pymic.loss.seg.util import get_soft_label
 from pymic.loss.seg.util import reshape_prediction_and_ground_truth
 from pymic.loss.seg.util import get_classwise_dice
-from pymic.util.ramps import sigmoid_rampup
 from pymic.net_run_ssl.ssl_abstract import SSLSegAgent
 from pymic.net.net_dict_seg import SegNetDict
+from pymic.util.ramps import sigmoid_rampup
+from pymic.util.general import keyword_match

 class SSLMeanTeacher(SSLSegAgent):
     """
@@ -89,7 +90,8 @@ def training(self):

             loss.backward()
             self.optimizer.step()
-            self.scheduler.step()
+            if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")):
+                self.scheduler.step()

             # update EMA
             alpha = ssl_cfg.get('ema_decay', 0.99)
diff --git a/pymic/net_run_ssl/ssl_uamt.py b/pymic/net_run_ssl/ssl_uamt.py
index d1de32f..3352231 100644
--- a/pymic/net_run_ssl/ssl_uamt.py
+++ b/pymic/net_run_ssl/ssl_uamt.py
@@ -6,8 +6,9 @@ from pymic.loss.seg.util import get_soft_label
 from pymic.loss.seg.util import reshape_prediction_and_ground_truth
 from pymic.loss.seg.util import get_classwise_dice
-from pymic.util.ramps import sigmoid_rampup
 from pymic.net_run_ssl.ssl_mt import SSLMeanTeacher
+from pymic.util.ramps import sigmoid_rampup
+from pymic.util.general import keyword_match

 class SSLUncertaintyAwareMeanTeacher(SSLMeanTeacher):
     """
@@ -97,7 +98,9 @@ def training(self):

             loss.backward()
             self.optimizer.step()
-            self.scheduler.step()
+            if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")):
+                self.scheduler.step()
+

             # update EMA
             alpha = ssl_cfg.get('ema_decay', 0.99)
diff --git a/pymic/net_run_ssl/ssl_urpc.py b/pymic/net_run_ssl/ssl_urpc.py
index d9dd953..0513dc2 100644
--- a/pymic/net_run_ssl/ssl_urpc.py
+++ b/pymic/net_run_ssl/ssl_urpc.py
@@ -7,8 +7,9 @@ from pymic.loss.seg.util import get_soft_label
 from pymic.loss.seg.util import reshape_prediction_and_ground_truth
 from pymic.loss.seg.util import get_classwise_dice
-from pymic.util.ramps import sigmoid_rampup
 from pymic.net_run_ssl.ssl_abstract import SSLSegAgent
+from pymic.util.ramps import sigmoid_rampup
+from pymic.util.general import keyword_match

 class SSLURPC(SSLSegAgent):
     """
@@ -90,8 +91,8 @@ def training(self):

             loss.backward()
             self.optimizer.step()
-            self.scheduler.step()
-
+            if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")):
+                self.scheduler.step()
             train_loss = train_loss + loss.item()
             train_loss_sup = train_loss_sup + loss_sup.item()
             train_loss_reg = train_loss_reg + loss_reg.item()
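An aside (not part of the patch series): all of these agents ramp the regularization weight with pymic.util.ramps.sigmoid_rampup, which patch 19 below also touches. A sketch of that schedule, assuming the common exp(-5*(1-t)^2) form; the actual implementation in pymic/util/ramps.py may differ in detail:

    import numpy as np

    def sigmoid_rampup_sketch(current_it, ramp_up_length):
        # rises smoothly from exp(-5) ~ 0.007 at t = 0 to 1.0 at t = 1
        t = np.clip(current_it / float(ramp_up_length), 0.0, 1.0)
        return float(np.exp(-5.0 * (1.0 - t) ** 2))

    regular_w = 0.1 * sigmoid_rampup_sketch(2000, 10000)   # ~0.004 early in training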
From e057152fe4d6aaffc933c684bd26925ddfd3a95e Mon Sep 17 00:00:00 2001
From: taigw
Date: Tue, 9 Aug 2022 17:27:27 +0800
Subject: [PATCH 13/26] update network and ssl cct

allow dropout in the decoder
add CCT for SSL
---
 pymic/net/net2d/unet2d.py       |   8 +-
 pymic/net/net2d/unet2d_cct.py   | 195 ++++++++++++++++++++++++++++++++
 pymic/net/net2d/unet2d_scse.py  |   8 +-
 pymic/net/net3d/unet2d5.py      |   8 +-
 pymic/net/net3d/unet3d.py       |   8 +-
 pymic/net/net3d/unet3d_scse.py  |   8 +-
 pymic/net/net_dict_seg.py       |   2 +
 pymic/net_run/agent_abstract.py |   2 +-
 pymic/net_run_ssl/ssl_cct.py    | 156 +++++++++++++++++++++++++
 pymic/net_run_ssl/ssl_cps.py    |   5 +-
 pymic/net_run_ssl/ssl_main.py   |   9 +-
 11 files changed, 383 insertions(+), 26 deletions(-)
 create mode 100644 pymic/net/net2d/unet2d_cct.py
 create mode 100644 pymic/net_run_ssl/ssl_cct.py

diff --git a/pymic/net/net2d/unet2d.py b/pymic/net/net2d/unet2d.py
index 0cc607f..703ced3 100644
--- a/pymic/net/net2d/unet2d.py
+++ b/pymic/net/net2d/unet2d.py
@@ -91,10 +91,10 @@ def __init__(self, params):
         self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3])
         if(len(self.ft_chns) == 5):
             self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4])
-        self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], 0.0, self.bilinear)
-        self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], 0.0, self.bilinear)
-        self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], 0.0, self.bilinear)
-        self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], 0.0, self.bilinear)
+        self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.bilinear)
+        self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.bilinear)
+        self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.bilinear)
+        self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.bilinear)

         self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size = 1)
         if(self.deep_sup):
diff --git a/pymic/net/net2d/unet2d_cct.py b/pymic/net/net2d/unet2d_cct.py
new file mode 100644
index 0000000..88a369f
--- /dev/null
+++ b/pymic/net/net2d/unet2d_cct.py
@@ -0,0 +1,195 @@
+# -*- coding: utf-8 -*-
+"""
+A modification of the U-Net with auxiliary decoders according to
+the CCT paper:
+    Yassine Ouali, Celine Hudelot and Myriam Tami:
+    Semi-Supervised Semantic Segmentation With Cross-Consistency Training.
+    CVPR 2020.
+    https://arxiv.org/abs/2003.09005
+Code adapted from: https://github.com/yassouali/CCT
+"""
+from __future__ import print_function, division
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from torch.distributions.uniform import Uniform
+from pymic.net.net2d.unet2d import ConvBlock, DownBlock, UpBlock
+
+class Encoder(nn.Module):
+    def __init__(self, params):
+        super(Encoder, self).__init__()
+        self.params  = params
+        self.in_chns = self.params['in_chns']
+        self.ft_chns = self.params['feature_chns']
+        self.dropout = self.params['dropout']
+        assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4)
+
+        self.in_conv= ConvBlock(self.in_chns, self.ft_chns[0], self.dropout[0])
+        self.down1  = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1])
+        self.down2  = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2])
+        self.down3  = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3])
+        if(len(self.ft_chns) == 5):
+            self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4])
+
+    def forward(self, x):
+        x0 = self.in_conv(x)
+        x1 = self.down1(x0)
+        x2 = self.down2(x1)
+        x3 = self.down3(x2)
+        output = [x0, x1, x2, x3]
+        if(len(self.ft_chns) == 5):
+            x4 = self.down4(x3)
+            output.append(x4)
+        return output
+
+class Decoder(nn.Module):
+    def __init__(self, params):
+        super(Decoder, self).__init__()
+        self.params  = params
+        self.in_chns = self.params['in_chns']
+        self.ft_chns = self.params['feature_chns']
+        self.dropout = self.params['dropout']
+        self.n_class = self.params['class_num']
+        self.bilinear = self.params['bilinear']
+
+        assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4)
+
+        if(len(self.ft_chns) == 5):
+            self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.bilinear)
+        self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.bilinear)
+        self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.bilinear)
+        self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.bilinear)
+        self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size = 1)
+
+    def forward(self, x):
+        if(len(self.ft_chns) == 5):
+            assert(len(x) == 5)
+            x0, x1, x2, x3, x4 = x
+            x_d3 = self.up1(x4, x3)
+        else:
+            assert(len(x) == 4)
+            x0, x1, x2, x3 = x
+            x_d3 = x3
+        x_d2 = self.up2(x_d3, x2)
+        x_d1 = self.up3(x_d2, x1)
+        x_d0 = self.up4(x_d1, x0)
+        output = self.out_conv(x_d0)
+        return output
+
+def _l2_normalize(d):
+    # Normalizing per batch axis
+    d_reshaped = d.view(d.shape[0], -1, *(1 for _ in range(d.dim() - 2)))
+    d /= torch.norm(d_reshaped, dim=1, keepdim=True) + 1e-8
+    return d
+
+
+
+def get_r_adv(x_list, decoder, it=1, xi=1e-1, eps=10.0):
+    """
+    Virtual Adversarial Training according to
+    https://arxiv.org/abs/1704.03976
+    """
+    x_detached  = [item.detach() for item in x_list]
+    xe_detached = x_detached[-1]
+    with torch.no_grad():
+        pred = F.softmax(decoder(x_detached), dim=1)
+
+    d = torch.rand(x_list[-1].shape).sub(0.5).to(x_list[-1].device)
+    d = _l2_normalize(d)
+
+    for _ in range(it):
+        d.requires_grad_()
+        x_detached[-1] = xe_detached + xi * d
+        pred_hat = decoder(x_detached)
+        logp_hat = F.log_softmax(pred_hat, dim=1)
+        adv_distance = F.kl_div(logp_hat, pred, reduction='batchmean')
+        adv_distance.backward()
+        d = _l2_normalize(d.grad)
+        decoder.zero_grad()
+
+    r_adv = d * eps
+    return x_list[-1] + r_adv
+
+
+class AuxiliaryDecoder(nn.Module):
+    def __init__(self, params, aux_type):
+        super(AuxiliaryDecoder, self).__init__()
+        self.params   = params
+        self.decoder  = Decoder(params)
+        self.aux_type = aux_type
+        uniform_range = params.get("Uniform_range".lower(), 0.3)
+        self.uni_dist = Uniform(-uniform_range, uniform_range)
+
+    def feature_drop(self, x):
+        attention = torch.mean(x, dim=1, keepdim=True)
+        max_val, _ = torch.max(attention.view(x.size(0), -1), dim=1, keepdim=True)
+        threshold = max_val * np.random.uniform(0.7, 0.9)
+        threshold = threshold.view(x.size(0), 1, 1, 1).expand_as(attention)
+        drop_mask = (attention < threshold).float()
+        return x.mul(drop_mask)
+
+    def feature_based_noise(self, x):
+        noise_vector = self.uni_dist.sample(x.shape[1:]).to(x.device).unsqueeze(0)
+        x_noise = x.mul(noise_vector) + x
+        return x_noise
+
+    def forward(self, x):
+        if(self.aux_type == "DropOut"):
+            pass
+        elif(self.aux_type == "FeatureDrop"):
+            x[-1] = self.feature_drop(x[-1])
+        elif(self.aux_type == "FeatureNoise"):
+            x[-1] = self.feature_based_noise(x[-1])
+        elif(self.aux_type == "VAT"):
+            it  = self.params.get("VAT_it".lower(), 2)
+            xi  = self.params.get("VAT_xi".lower(), 1e-6)
+            eps = self.params.get("VAT_eps".lower(), 2.0)
+            x[-1] = get_r_adv(x, self.decoder, it, xi, eps)
+        else:
+            raise ValueError("Undefined auxiliary decoder type {0:}".format(self.aux_type))
+
+        output = self.decoder(x)
+        return output
+
+
+class UNet2D_CCT(nn.Module):
+    def __init__(self, params):
+        super(UNet2D_CCT, self).__init__()
+        self.params  = params
+        self.encoder = Encoder(params)
+        self.decoder = Decoder(params)
+        aux_names = params.get("CCT_aux_decoders".lower(), None)
+        if aux_names is None:
+            aux_names = ["DropOut", "FeatureDrop", "FeatureNoise", "VAT"]
+        aux_decoders = []
+        for aux_name in aux_names:
+            aux_decoders.append(AuxiliaryDecoder(params, aux_name))
+        self.aux_decoders = nn.ModuleList(aux_decoders)
+
+
+    def forward(self, x):
+        x_shape = list(x.shape)
+        if(len(x_shape) == 5):
+            [N, C, D, H, W] = x_shape
+            new_shape = [N*D, C, H, W]
+            x = torch.transpose(x, 1, 2)
+            x = torch.reshape(x, new_shape)
+
+        f = self.encoder(x)
+        output = self.decoder(f)
+        if(len(x_shape) == 5):
+            new_shape = [N, D] + list(output.shape)[1:]
+            output = torch.reshape(output, new_shape)
+            output = torch.transpose(output, 1, 2)
+
+        if(self.training):
+            aux_outputs = [aux_d(f) for aux_d in self.aux_decoders]
+            if(len(x_shape) == 5):
+                for i in range(len(aux_outputs)):
+                    aux_outi = torch.reshape(aux_outputs[i], new_shape)
+                    aux_outputs[i] = torch.transpose(aux_outi, 1, 2)
+            return output, aux_outputs
+        else:
+            return output
\ No newline at end of file
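For illustration (not part of the patch series): a hypothetical shape check of the new network, with the parameter dict laid out like the other 2D U-Nets in this repository. Note that AuxiliaryDecoder.forward assigns to x[-1] in place, so perturbations from successive auxiliary decoders compound on the shared feature list.

    import torch
    from pymic.net.net2d.unet2d_cct import UNet2D_CCT

    params = {'in_chns': 1, 'class_num': 2, 'bilinear': True,
              'feature_chns': [16, 32, 64, 128, 256],
              'dropout': [0.0, 0.0, 0.2, 0.3, 0.5]}
    net = UNet2D_CCT(params)
    x = torch.rand(2, 1, 96, 96)
    net.train()
    y, aux = net(x)            # main prediction plus one output per auxiliary decoder
    print(y.shape, len(aux))   # torch.Size([2, 2, 96, 96]) 4
    net.eval()
    print(net(x).shape)        # only the main decoder is used at test time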
diff --git a/pymic/net/net2d/unet2d_scse.py b/pymic/net/net2d/unet2d_scse.py
index 95b25b1..6f9d3f7 100644
--- a/pymic/net/net2d/unet2d_scse.py
+++ b/pymic/net/net2d/unet2d_scse.py
@@ -79,10 +79,10 @@ def __init__(self, params):
         self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2])
         self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3])
         self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4])
-        self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p = 0.0)
-        self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p = 0.0)
-        self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p = 0.0)
-        self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p = 0.0)
+        self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p = self.dropout[3])
+        self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p = self.dropout[2])
+        self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p = self.dropout[1])
+        self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p = self.dropout[0])

         self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class,
             kernel_size = 3, padding = 1)
diff --git a/pymic/net/net3d/unet2d5.py b/pymic/net/net3d/unet2d5.py
index 4ed8d7e..9e6a72d 100644
--- a/pymic/net/net3d/unet2d5.py
+++ b/pymic/net/net3d/unet2d5.py
@@ -149,13 +149,13 @@ def __init__(self, params):
         self.block3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dims[3], self.dropout[3], True)
         self.block4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dims[4], self.dropout[4], False)
         self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3],
-            self.dims[3], dropout_p = 0.0, bilinear = self.bilinear)
+            self.dims[3], dropout_p = self.dropout[3], bilinear = self.bilinear)
         self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2],
-            self.dims[2], dropout_p = 0.0, bilinear = self.bilinear)
+            self.dims[2], dropout_p = self.dropout[2], bilinear = self.bilinear)
         self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1],
-            self.dims[1], dropout_p = 0.0, bilinear = self.bilinear)
+            self.dims[1], dropout_p = self.dropout[1], bilinear = self.bilinear)
         self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0],
-            self.dims[0], dropout_p = 0.0, bilinear = self.bilinear)
+            self.dims[0], dropout_p = self.dropout[0], bilinear = self.bilinear)

         self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class,
             kernel_size = (1, 3, 3), padding = (0, 1, 1))
diff --git a/pymic/net/net3d/unet3d.py b/pymic/net/net3d/unet3d.py
index a37204e..058cb79 100644
--- a/pymic/net/net3d/unet3d.py
+++ b/pymic/net/net3d/unet3d.py
@@ -106,13 +106,13 @@ def __init__(self, params):
         if(len(self.ft_chns) == 5):
             self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4])
         self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3],
-            dropout_p = 0.0, trilinear=self.trilinear)
+            dropout_p = self.dropout[3], trilinear=self.trilinear)
         self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2],
-            dropout_p = 0.0, trilinear=self.trilinear)
+            dropout_p = self.dropout[2], trilinear=self.trilinear)
         self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1],
-            dropout_p = 0.0, trilinear=self.trilinear)
+            dropout_p = self.dropout[1], trilinear=self.trilinear)
         self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0],
-            dropout_p = 0.0, trilinear=self.trilinear)
+            dropout_p = self.dropout[0], trilinear=self.trilinear)

         self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, kernel_size = 1)
         if(self.deep_sup):
diff --git a/pymic/net/net3d/unet3d_scse.py b/pymic/net/net3d/unet3d_scse.py
index 5832830..0f15e25 100644
--- a/pymic/net/net3d/unet3d_scse.py
+++ b/pymic/net/net3d/unet3d_scse.py
@@ -78,10 +78,10 @@ def __init__(self, params):
         self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2])
         self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3])
         self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4])
-        self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p = 0.0)
-        self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p = 0.0)
-        self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p = 0.0)
-        self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p = 0.0)
+        self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p = self.dropout[3])
+        self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p = self.dropout[2])
+        self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p = self.dropout[1])
+        self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p = self.dropout[0])

         self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class,
             kernel_size = 3, padding = 1)
diff --git a/pymic/net/net_dict_seg.py b/pymic/net/net_dict_seg.py
index 55711d4..aa912bc 100644
--- a/pymic/net/net_dict_seg.py
+++ b/pymic/net/net_dict_seg.py
@@ -3,6 +3,7 @@
 from pymic.net.net2d.unet2d import UNet2D
 from pymic.net.net2d.unet2d_dual_branch import DualBranchUNet2D
 from pymic.net.net2d.unet2d_urpc import UNet2D_URPC
+from pymic.net.net2d.unet2d_cct import UNet2D_CCT
 from pymic.net.net2d.cople_net import COPLENet
 from pymic.net.net2d.unet2d_attention import AttentionUNet2D
 from pymic.net.net2d.unet2d_nest import NestedUNet2D
@@ -15,6 +16,7 @@
     'UNet2D': UNet2D,
     'DualBranchUNet2D': DualBranchUNet2D,
     'UNet2D_URPC': UNet2D_URPC,
+    'UNet2D_CCT': UNet2D_CCT,
     'COPLENet': COPLENet,
     'AttentionUNet2D': AttentionUNet2D,
     'NestedUNet2D': NestedUNet2D,
diff --git a/pymic/net_run/agent_abstract.py b/pymic/net_run/agent_abstract.py
index d701534..8ffadc1 100644
--- a/pymic/net_run/agent_abstract.py
+++ b/pymic/net_run/agent_abstract.py
@@ -157,7 +157,7 @@ def create_optimizer(self, params):
             self.optimizer.load_state_dict(self.checkpoint['optimizer_state_dict'])
             last_iter = self.checkpoint['iteration'] - 1
         if(self.scheduler is None):
-            opt_params["laster_iter"] = last_iter
+            opt_params["last_iter"] = last_iter
             self.scheduler = get_lr_scheduler(self.optimizer, opt_params)

     def convert_tensor_type(self, input_tensor):
diff --git a/pymic/net_run_ssl/ssl_cct.py b/pymic/net_run_ssl/ssl_cct.py
new file mode 100644
index 0000000..80e9d07
--- /dev/null
+++ b/pymic/net_run_ssl/ssl_cct.py
@@ -0,0 +1,156 @@
+# -*- coding: utf-8 -*-
+from __future__ import print_function, division
+import logging
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from pymic.loss.seg.util import get_soft_label
+from pymic.loss.seg.util import reshape_prediction_and_ground_truth
+from pymic.loss.seg.util import get_classwise_dice
+from pymic.net_run_ssl.ssl_abstract import SSLSegAgent
+from pymic.util.ramps import sigmoid_rampup
+from pymic.util.general import keyword_match
+
+def softmax_mse_loss(inputs, targets, conf_mask=False, threshold=None, use_softmax=False):
+    assert inputs.requires_grad == True and targets.requires_grad == False
+    assert inputs.size() == targets.size() # (batch_size * num_classes * H * W)
+    inputs = F.softmax(inputs, dim=1)
+    if use_softmax:
+        targets = F.softmax(targets, dim=1)
+
+    if conf_mask:
+        loss_mat = F.mse_loss(inputs, targets, reduction='none')
+        mask = (targets.max(1)[0] > threshold)
+        loss_mat = loss_mat[mask.unsqueeze(1).expand_as(loss_mat)]
+        if loss_mat.shape.numel() == 0: loss_mat = torch.tensor([0.]).to(inputs.device)
+        return loss_mat.mean()
+    else:
+        return F.mse_loss(inputs, targets, reduction='mean') # take the mean over the batch_size
+
+
+def softmax_kl_loss(inputs, targets, conf_mask=False, threshold=None, use_softmax=False):
+    assert inputs.requires_grad == True and targets.requires_grad == False
+    assert inputs.size() == targets.size()
+    input_log_softmax = F.log_softmax(inputs, dim=1)
+    if use_softmax:
+        targets = F.softmax(targets, dim=1)
+
+    if conf_mask:
+        loss_mat = F.kl_div(input_log_softmax, targets, reduction='none')
+        mask = (targets.max(1)[0] > threshold)
+        loss_mat = loss_mat[mask.unsqueeze(1).expand_as(loss_mat)]
+        if loss_mat.shape.numel() == 0: loss_mat = torch.tensor([0.]).to(inputs.device)
+        return loss_mat.sum() / mask.shape.numel()
+    else:
+        return F.kl_div(input_log_softmax, targets, reduction='mean')
+
+
+def softmax_js_loss(inputs, targets, **_):
+    assert inputs.requires_grad == True and targets.requires_grad == False
+    assert inputs.size() == targets.size()
+    epsilon = 1e-5
+
+    M = (F.softmax(inputs, dim=1) + targets) * 0.5
+    kl1 = F.kl_div(F.log_softmax(inputs, dim=1), M, reduction='mean')
+    kl2 = F.kl_div(torch.log(targets+epsilon), M, reduction='mean')
+    return (kl1 + kl2) * 0.5
+
+unsup_loss_dict = {"MSE": softmax_mse_loss,
+    "KL":softmax_kl_loss,
+    "JS":softmax_js_loss}
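For illustration (not part of the patch series): a tiny sanity check of the consistency terms on random logits. With use_softmax=True the target logits are converted to probabilities inside the loss; both results are non-negative scalars.

    import torch
    from pymic.net_run_ssl.ssl_cct import softmax_mse_loss, softmax_kl_loss

    logits_aux  = torch.randn(2, 3, 8, 8, requires_grad = True)  # auxiliary decoder
    logits_main = torch.randn(2, 3, 8, 8)                        # main decoder, detached
    print(softmax_mse_loss(logits_aux, logits_main, use_softmax = True).item())
    print(softmax_kl_loss(logits_aux, logits_main, use_softmax = True).item())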
+class SSLCCT(SSLSegAgent):
+    """
+    Cross-Consistency Training according to the following paper:
+    Yassine Ouali, Celine Hudelot and Myriam Tami:
+    Semi-Supervised Semantic Segmentation With Cross-Consistency Training.
+    CVPR 2020.
+    https://arxiv.org/abs/2003.09005
+    Code adapted from: https://github.com/yassouali/CCT
+    """
+    def training(self):
+        class_num  = self.config['network']['class_num']
+        iter_valid = self.config['training']['iter_valid']
+        ssl_cfg    = self.config['semi_supervised_learning']
+        unsup_loss_name = ssl_cfg.get('unsupervised_loss', "MSE")
+        self.unsup_loss_f = unsup_loss_dict[unsup_loss_name]
+        train_loss = 0
+        train_loss_sup = 0
+        train_loss_reg = 0
+        train_dice_list = []
+        self.net.train()
+
+        for it in range(iter_valid):
+            try:
+                data_lab = next(self.trainIter)
+            except StopIteration:
+                self.trainIter = iter(self.train_loader)
+                data_lab = next(self.trainIter)
+            try:
+                data_unlab = next(self.trainIter_unlab)
+            except StopIteration:
+                self.trainIter_unlab = iter(self.train_loader_unlab)
+                data_unlab = next(self.trainIter_unlab)
+
+            # get the inputs
+            x0 = self.convert_tensor_type(data_lab['image'])
+            y0 = self.convert_tensor_type(data_lab['label_prob'])
+            x1 = self.convert_tensor_type(data_unlab['image'])
+            inputs = torch.cat([x0, x1], dim = 0)
+            inputs, y0 = inputs.to(self.device), y0.to(self.device)
+
+            # zero the parameter gradients
+            self.optimizer.zero_grad()
+
+            # forward pass
+            output, aux_outputs = self.net(inputs)
+            n0 = list(x0.shape)[0]
+
+            # get supervised loss
+            p0 = output[:n0]
+            loss_sup = self.get_loss_value(data_lab, p0, y0)
+
+            # get regularization loss
+            p1 = F.softmax(output[n0:].detach(), dim=1)
+            p1_aux = [aux_out[n0:] for aux_out in aux_outputs]
+            loss_reg = 0.0
+            for p1_auxi in p1_aux:
+                loss_reg += self.unsup_loss_f( p1_auxi, p1, use_softmax = True)
+            loss_reg = loss_reg / len(p1_aux)
+
+            iter_max = self.config['training']['iter_max']
+            ramp_up_length = ssl_cfg.get('ramp_up_length', iter_max)
+            regular_w = 0.0
+            if(self.glob_it > ssl_cfg.get('iter_sup', 0)):
+                regular_w = ssl_cfg.get('regularize_w', 0.1)
+                if(ramp_up_length is not None and self.glob_it < ramp_up_length):
+                    regular_w = regular_w * sigmoid_rampup(self.glob_it, ramp_up_length)
+
+            loss = loss_sup + regular_w*loss_reg
+
+            loss.backward()
+            self.optimizer.step()
+            if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")):
+                self.scheduler.step()
+            train_loss = train_loss + loss.item()
+            train_loss_sup = train_loss_sup + loss_sup.item()
+            train_loss_reg = train_loss_reg + loss_reg.item()
+            # get dice evaluation for each class in annotated images
+            if(isinstance(p0, tuple) or isinstance(p0, list)):
+                p0 = p0[0]
+            p0_argmax = torch.argmax(p0, dim = 1, keepdim = True)
+            p0_soft   = get_soft_label(p0_argmax, class_num, self.tensor_type)
+            p0_soft, y0 = reshape_prediction_and_ground_truth(p0_soft, y0)
+            dice_list = get_classwise_dice(p0_soft, y0)
+            train_dice_list.append(dice_list.cpu().numpy())
+        train_avg_loss = train_loss / iter_valid
+        train_avg_loss_sup = train_loss_sup / iter_valid
+        train_avg_loss_reg = train_loss_reg / iter_valid
+        train_cls_dice = np.asarray(train_dice_list).mean(axis = 0)
+        train_avg_dice = train_cls_dice.mean()
+
+        train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup,
+            'loss_reg':train_avg_loss_reg, 'regular_w':regular_w,
+            'avg_dice':train_avg_dice, 'class_dice': train_cls_dice}
+        return train_scalers
diff --git a/pymic/net_run_ssl/ssl_cps.py b/pymic/net_run_ssl/ssl_cps.py
index 582cdc0..e39bec9 100644
--- a/pymic/net_run_ssl/ssl_cps.py
+++ b/pymic/net_run_ssl/ssl_cps.py
@@ -51,7 +51,7 @@ def train_valid(self):
         #     last_iter = self.checkpoint['iteration'] - 1
         if(self.scheduler2 is None):
             opt_params["last_iter"] = last_iter
-            self.scheduler2 = get_lr_scheduler(self.optimizer, opt_params)
+            self.scheduler2 = get_lr_scheduler(self.optimizer2, opt_params)
         super(SSLCPS, self).train_valid()

     def training(self):
@@ -159,7 +159,7 @@ def validation(self):
             self.scheduler2.step(return_value['avg_dice'])
         return return_value

-    def write_scalars(self, train_scalars, valid_scalars, glob_it):
+    def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it):
         loss_scalar ={'train':train_scalars['loss'],
                       'valid':valid_scalars['loss']}
         loss_sup_scalar  = {'net1':train_scalars['loss_sup1'],
@@ -171,6 +171,7 @@ def write_scalars(self, train_scalars, valid_scalars, glob_it):
         self.summ_writer.add_scalars('loss_sup', loss_sup_scalar, glob_it)
         self.summ_writer.add_scalars('loss_pseudo_sup', loss_pse_sup_scalar, glob_it)
         self.summ_writer.add_scalars('regular_w', {'regular_w':train_scalars['regular_w']}, glob_it)
+        self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it)
         self.summ_writer.add_scalars('dice', dice_scalar, glob_it)
         class_num = self.config['network']['class_num']
         for c in range(class_num):
diff --git a/pymic/net_run_ssl/ssl_main.py b/pymic/net_run_ssl/ssl_main.py
index 54492ae..6bddf29 100644
--- a/pymic/net_run_ssl/ssl_main.py
+++ b/pymic/net_run_ssl/ssl_main.py
@@ -8,14 +8,17 @@ from pymic.net_run_ssl.ssl_em import SSLEntropyMinimization
 from pymic.net_run_ssl.ssl_mt import SSLMeanTeacher
 from pymic.net_run_ssl.ssl_uamt import SSLUncertaintyAwareMeanTeacher
-from pymic.net_run_ssl.ssl_urpc import SSLURPC
+from pymic.net_run_ssl.ssl_cct import SSLCCT
 from pymic.net_run_ssl.ssl_cps import SSLCPS
+from pymic.net_run_ssl.ssl_urpc import SSLURPC
+

 SSLMethodDict = {'EntropyMinimization': SSLEntropyMinimization,
     'MeanTeacher': SSLMeanTeacher,
     'UAMT': SSLUncertaintyAwareMeanTeacher,
-    'URPC': SSLURPC,
-    'CPS': SSLCPS}
+    'CCT': SSLCCT,
+    'CPS': SSLCPS,
+    'URPC': SSLURPC}

 def main():
     if(len(sys.argv) < 3):

From 153afe56f548d2eea55b65368510a404fd92c00a Mon Sep 17 00:00:00 2001
From: taigw
Date: Thu, 11 Aug 2022 10:00:56 +0800
Subject: [PATCH 14/26] update reference for ssl methods

---
 pymic/net_run_ssl/ssl_em.py   |  5 +++--
 pymic/net_run_ssl/ssl_mt.py   |  6 +++++-
 pymic/net_run_ssl/ssl_urpc.py | 11 +++++------
 3 files changed, 13 insertions(+), 9 deletions(-)

diff --git a/pymic/net_run_ssl/ssl_em.py b/pymic/net_run_ssl/ssl_em.py
index bb1cd55..e4b1f19 100644
--- a/pymic/net_run_ssl/ssl_em.py
+++ b/pymic/net_run_ssl/ssl_em.py
@@ -15,9 +15,10 @@
 class SSLEntropyMinimization(SSLSegAgent):
     """
     Implementation of the following paper:
-    Yves Grandvalet and Yoshua Bengio,
+    Yves Grandvalet and Yoshua Bengio:
     Semi-supervised Learning by Entropy Minimization.
-    NeurIPS, 2005.
+    NeurIPS, 2005.
+    https://papers.nips.cc/paper/2004/file/96f2b50b5d3613adf9c27049b2a888c7-Paper.pdf
     """
     def __init__(self, config, stage = 'train'):
         super(SSLEntropyMinimization, self).__init__(config, stage)
diff --git a/pymic/net_run_ssl/ssl_mt.py b/pymic/net_run_ssl/ssl_mt.py
index 7905968..aa7fbff 100644
--- a/pymic/net_run_ssl/ssl_mt.py
+++ b/pymic/net_run_ssl/ssl_mt.py
@@ -13,7 +13,11 @@

 class SSLMeanTeacher(SSLSegAgent):
     """
-    Training and testing agent for semi-supervised segmentation
+    Mean Teacher for semi-supervised learning according to the following paper:
+    Antti Tarvainen, Harri Valpola: Mean teachers are better role models: Weight-averaged
+    consistency targets improve semi-supervised deep learning results.
+    NeurIPS 2017.
+    https://arxiv.org/abs/1703.01780
     """
     def __init__(self, config, stage = 'train'):
         super(SSLMeanTeacher, self).__init__(config, stage)
diff --git a/pymic/net_run_ssl/ssl_urpc.py b/pymic/net_run_ssl/ssl_urpc.py
index 0513dc2..5d269fd 100644
--- a/pymic/net_run_ssl/ssl_urpc.py
+++ b/pymic/net_run_ssl/ssl_urpc.py
@@ -14,12 +14,11 @@
 class SSLURPC(SSLSegAgent):
     """
     Uncertainty-Rectified Pyramid Consistency according to the following paper:
-    Xiangde Luo, Wenjun Liao, Jieneng Chen, Tao Song, Yinan Chen,
-    Shichuan Zhang, Nianyong Chen, Guotai Wang, Shaoting Zhang.
-    Efficient Semi-supervised Gross Target Volume of Nasopharyngeal Carcinoma
-    Segmentation via Uncertainty Rectified Pyramid Consistency.
-    MICCAI 2021, pp. 318-329.
-    https://arxiv.org/abs/2012.07042
+    Xiangde Luo, Guotai Wang*, Wenjun Liao, Jieneng Chen, Tao Song, Yinan Chen,
+    Shichuan Zhang, Dimitris N. Metaxas, Shaoting Zhang.
+    Semi-Supervised Medical Image Segmentation via Uncertainty Rectified Pyramid Consistency.
+    Medical Image Analysis 2022.
+    https://doi.org/10.1016/j.media.2022.102517
     """
     def training(self):
         class_num = self.config['network']['class_num']

From 89fd6ccf996e57f8ca60963ebfdf8455be5bcf42 Mon Sep 17 00:00:00 2001
From: taigw
Date: Thu, 11 Aug 2022 12:36:39 +0800
Subject: [PATCH 15/26] update network for wsl

Rename WSL classes, update learning rate scheduler and update dual-branch
network
---
 pymic/net/net2d/unet2d.py             | 61 ++++++++++++++++++++++++++
 pymic/net/net2d/unet2d_cct.py         | 63 +--------------------------
 pymic/net/net2d/unet2d_dual_branch.py | 32 +++++++++++++-
 pymic/net/net_dict_seg.py             |  4 +-
 pymic/net_run_wsl/wsl_abstract.py     |  5 ++-
 pymic/net_run_wsl/wsl_dmpls.py        |  9 ++--
 pymic/net_run_wsl/wsl_em.py           |  4 +-
 pymic/net_run_wsl/wsl_gatedcrf.py     |  3 +-
 pymic/net_run_wsl/wsl_main.py         | 24 +++++-----
 pymic/net_run_wsl/wsl_tv.py           |  6 ++-
 pymic/net_run_wsl/wsl_ustm.py         |  4 +-
 11 files changed, 127 insertions(+), 88 deletions(-)

diff --git a/pymic/net/net2d/unet2d.py b/pymic/net/net2d/unet2d.py
index 703ced3..a361f0f 100644
--- a/pymic/net/net2d/unet2d.py
+++ b/pymic/net/net2d/unet2d.py
@@ -72,6 +72,67 @@ def forward(self, x1, x2):
         x = torch.cat([x2, x1], dim=1)
         return self.conv(x)

+class Encoder(nn.Module):
+    def __init__(self, params):
+        super(Encoder, self).__init__()
+        self.params  = params
+        self.in_chns = self.params['in_chns']
+        self.ft_chns = self.params['feature_chns']
+        self.dropout = self.params['dropout']
+        assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4)
+
+        self.in_conv= ConvBlock(self.in_chns, self.ft_chns[0], self.dropout[0])
+        self.down1  = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1])
+        self.down2  = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2])
+        self.down3  = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3])
+        if(len(self.ft_chns) == 5):
+            self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4])
+
+    def forward(self, x):
+        x0 = self.in_conv(x)
+        x1 = self.down1(x0)
+        x2 = self.down2(x1)
+        x3 = self.down3(x2)
+        output = [x0, x1, x2, x3]
+        if(len(self.ft_chns) == 5):
+            x4 = self.down4(x3)
+            output.append(x4)
+        return output
+
+class Decoder(nn.Module):
+    def __init__(self, params):
+        super(Decoder, self).__init__()
+        self.params  = params
+        self.in_chns = self.params['in_chns']
+        self.ft_chns = self.params['feature_chns']
+        self.dropout = self.params['dropout']
+        self.n_class = self.params['class_num']
+        self.bilinear = self.params['bilinear']
+
+        assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4)
+
+        if(len(self.ft_chns) == 5):
+            self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.bilinear)
+        self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.bilinear)
+        self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.bilinear)
+        self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.bilinear)
+        self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size = 1)
+
+    def forward(self, x):
+        if(len(self.ft_chns) == 5):
+            assert(len(x) == 5)
+            x0, x1, x2, x3, x4 = x
+            x_d3 = self.up1(x4, x3)
+        else:
+            assert(len(x) == 4)
+            x0, x1, x2, x3 = x
+            x_d3 = x3
+        x_d2 = self.up2(x_d3, x2)
+        x_d1 = self.up3(x_d2, x1)
+        x_d0 = self.up4(x_d1, x0)
+        output = self.out_conv(x_d0)
+        return output
+
 class UNet2D(nn.Module):
     def __init__(self, params):
         super(UNet2D, self).__init__()
diff --git a/pymic/net/net2d/unet2d_cct.py b/pymic/net/net2d/unet2d_cct.py
index 88a369f..f7558bc 100644
--- a/pymic/net/net2d/unet2d_cct.py
+++ b/pymic/net/net2d/unet2d_cct.py
@@ -15,68 +15,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
 from torch.distributions.uniform import Uniform
-from pymic.net.net2d.unet2d import ConvBlock, DownBlock, UpBlock
-
-class Encoder(nn.Module):
-    def __init__(self, params):
-        super(Encoder, self).__init__()
-        self.params  = params
-        self.in_chns = self.params['in_chns']
-        self.ft_chns = self.params['feature_chns']
-        self.dropout = self.params['dropout']
-        assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4)
-
-        self.in_conv= ConvBlock(self.in_chns, self.ft_chns[0], self.dropout[0])
-        self.down1  = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1])
-        self.down2  = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2])
-        self.down3  = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3])
-        if(len(self.ft_chns) == 5):
-            self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4])
-
-    def forward(self, x):
-        x0 = self.in_conv(x)
-        x1 = self.down1(x0)
-        x2 = self.down2(x1)
-        x3 = self.down3(x2)
-        output = [x0, x1, x2, x3]
-        if(len(self.ft_chns) == 5):
-            x4 = self.down4(x3)
-            output.append(x4)
-        return output
-
-class Decoder(nn.Module):
-    def __init__(self, params):
-        super(Decoder, self).__init__()
-        self.params  = params
-        self.in_chns = self.params['in_chns']
-        self.ft_chns = self.params['feature_chns']
-        self.dropout = self.params['dropout']
-        self.n_class = self.params['class_num']
-        self.bilinear = self.params['bilinear']
-
-        assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4)
-
-        if(len(self.ft_chns) == 5):
-            self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.bilinear)
-        self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.bilinear)
-        self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.bilinear)
-        self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.bilinear)
-        self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size = 1)
-
-    def forward(self, x):
-        if(len(self.ft_chns) == 5):
-            assert(len(x) == 5)
-            x0, x1, x2, x3, x4 = x
-            x_d3 = self.up1(x4, x3)
-        else:
-            assert(len(x) == 4)
-            x0, x1, x2, x3 = x
-            x_d3 = x3
-        x_d2 = self.up2(x_d3, x2)
-        x_d1 = self.up3(x_d2, x1)
-        x_d0 = self.up4(x_d1, x0)
-        output = self.out_conv(x_d0)
-        return output
+from pymic.net.net2d.unet2d import Encoder, Decoder

 def _l2_normalize(d):
     # Normalizing per batch axis
diff --git a/pymic/net/net2d/unet2d_dual_branch.py b/pymic/net/net2d/unet2d_dual_branch.py
index 59ec138..9622bd0 100644
--- a/pymic/net/net2d/unet2d_dual_branch.py
+++ b/pymic/net/net2d/unet2d_dual_branch.py
@@ -11,10 +11,38 @@

 import torch
 import torch.nn as nn
-import numpy as np
-from torch.nn.functional import interpolate
 from pymic.net.net2d.unet2d import *

+class UNet2D_DualBranch(nn.Module):
+    def __init__(self, params):
+        super(UNet2D_DualBranch, self).__init__()
+        self.encoder  = Encoder(params)
+        self.decoder1 = Decoder(params)
+        self.decoder2 = Decoder(params)
+
+    def forward(self, x):
+        x_shape = list(x.shape)
+        if(len(x_shape) == 5):
+            [N, C, D, H, W] = x_shape
+            new_shape = [N*D, C, H, W]
+            x = torch.transpose(x, 1, 2)
+            x = torch.reshape(x, new_shape)
+
+        f = self.encoder(x)
+        output1 = self.decoder1(f)
+        output2 = self.decoder2(f)
+        if(len(x_shape) == 5):
+            new_shape = [N, D] + list(output1.shape)[1:]
+            output1 = torch.reshape(output1, new_shape)
+            output1 = torch.transpose(output1, 1, 2)
+            output2 = torch.reshape(output2, new_shape)
+            output2 = torch.transpose(output2, 1, 2)
+
+        if(self.training):
+            return output1, output2
+        else:
+            return (output1 + output2)/2
+
 # for backup
 class DualBranchUNet2D(UNet2D):
     def __init__(self, params):
         params['deep_supervise'] = False
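For illustration (not part of the patch series): the dual-branch contract in one place. Training returns one prediction per decoder (e.g. for the dynamically mixed pseudo labels in WSLDMPLS below), while inference averages the two branches. A hypothetical smoke test:

    import torch
    from pymic.net.net2d.unet2d_dual_branch import UNet2D_DualBranch

    params = {'in_chns': 1, 'class_num': 2, 'bilinear': True,
              'feature_chns': [16, 32, 64, 128, 256],
              'dropout': [0.0, 0.0, 0.2, 0.3, 0.5]}
    net = UNet2D_DualBranch(params)
    x = torch.rand(2, 1, 96, 96)
    net.train()
    p1, p2 = net(x)             # one output per branch
    net.eval()
    p = net(x)                  # (p1 + p2) / 2 at inference
    print(p1.shape, p.shape)    # both torch.Size([2, 2, 96, 96])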
diff --git a/pymic/net/net_dict_seg.py b/pymic/net/net_dict_seg.py
index aa912bc..0ee554e 100644
--- a/pymic/net/net_dict_seg.py
+++ b/pymic/net/net_dict_seg.py
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 from __future__ import print_function, division
 from pymic.net.net2d.unet2d import UNet2D
-from pymic.net.net2d.unet2d_dual_branch import DualBranchUNet2D
+from pymic.net.net2d.unet2d_dual_branch import UNet2D_DualBranch
 from pymic.net.net2d.unet2d_urpc import UNet2D_URPC
 from pymic.net.net2d.unet2d_cct import UNet2D_CCT
 from pymic.net.net2d.cople_net import COPLENet
@@ -14,7 +14,7 @@

 SegNetDict = {
     'UNet2D': UNet2D,
-    'DualBranchUNet2D': DualBranchUNet2D,
+    'UNet2D_DualBranch': UNet2D_DualBranch,
     'UNet2D_URPC': UNet2D_URPC,
     'UNet2D_CCT': UNet2D_CCT,
     'COPLENet': COPLENet,
diff --git a/pymic/net_run_wsl/wsl_abstract.py b/pymic/net_run_wsl/wsl_abstract.py
index fe80ea5..d64063e 100644
--- a/pymic/net_run_wsl/wsl_abstract.py
+++ b/pymic/net_run_wsl/wsl_abstract.py
@@ -5,7 +5,7 @@

 class WSLSegAgent(SegmentationAgent):
     """
-    Training and testing agent for semi-supervised segmentation
+    Training and testing agent for weakly supervised segmentation
     """
     def __init__(self, config, stage = 'train'):
         super(WSLSegAgent, self).__init__(config, stage)
@@ -13,7 +13,7 @@ def __init__(self, config, stage = 'train'):
     def training(self):
         pass

-    def write_scalars(self, train_scalars, valid_scalars, glob_it):
+    def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it):
         loss_scalar ={'train':train_scalars['loss'],
                       'valid':valid_scalars['loss']}
         loss_sup_scalar  = {'train':train_scalars['loss_sup']}
@@ -23,6 +23,7 @@ def write_scalars(self, train_scalars, valid_scalars, glob_it):
         self.summ_writer.add_scalars('loss_sup', loss_sup_scalar, glob_it)
         self.summ_writer.add_scalars('loss_reg', loss_upsup_scalar, glob_it)
         self.summ_writer.add_scalars('regular_w', {'regular_w':train_scalars['regular_w']}, glob_it)
+        self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it)
         self.summ_writer.add_scalars('dice', dice_scalar, glob_it)
         class_num = self.config['network']['class_num']
         for c in range(class_num):
diff --git a/pymic/net_run_wsl/wsl_dmpls.py b/pymic/net_run_wsl/wsl_dmpls.py
index 234f1a2..1e60e47 100644
--- a/pymic/net_run_wsl/wsl_dmpls.py
+++ b/pymic/net_run_wsl/wsl_dmpls.py
@@ -10,6 +10,7 @@
 from pymic.loss.seg.dice import DiceLoss
 from pymic.net_run_wsl.wsl_abstract import WSLSegAgent
 from pymic.util.ramps import sigmoid_rampup
+from pymic.util.general import keyword_match

 class WSLDMPLS(WSLSegAgent):
     """
@@ -18,12 +19,13 @@ class WSLDMPLS(WSLSegAgent):
     Shaoting Zhang. Scribble-Supervised Medical Image Segmentation via
     Dual-Branch Network and Dynamically Mixed Pseudo Labels Supervision.
     MICCAI 2022.
+    https://arxiv.org/abs/2203.02106
     """
     def __init__(self, config, stage = 'train'):
         net_type = config['network']['net_type']
-        if net_type not in ['DualBranchUNet2D', 'DualBranchUNet3D']:
+        if net_type not in ['UNet2D_DualBranch', 'UNet3D_DualBranch']:
             raise ValueError("""For WSL_DMPLS, a dual branch network is expected. \
-                It only supports DualBranchUNet2D and DualBranchUNet3D currently.""")
+                It only supports UNet2D_DualBranch and UNet3D_DualBranch currently.""")
         super(WSLDMPLS, self).__init__(config, stage)

     def training(self):
@@ -82,7 +84,8 @@ def training(self):

             loss.backward()
             self.optimizer.step()
-            self.scheduler.step()
+            if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")):
+                self.scheduler.step()

             train_loss = train_loss + loss.item()
             train_loss_sup = train_loss_sup + loss_sup.item()
diff --git a/pymic/net_run_wsl/wsl_em.py b/pymic/net_run_wsl/wsl_em.py
index cc19600..66823c1 100644
--- a/pymic/net_run_wsl/wsl_em.py
+++ b/pymic/net_run_wsl/wsl_em.py
@@ -10,6 +10,7 @@
 from pymic.net_run.agent_seg import SegmentationAgent
 from pymic.net_run_wsl.wsl_abstract import WSLSegAgent
 from pymic.util.ramps import sigmoid_rampup
+from pymic.util.general import keyword_match

 class WSLEntropyMinimization(WSLSegAgent):
     """
@@ -60,7 +61,8 @@ def training(self):

             loss.backward()
             self.optimizer.step()
-            self.scheduler.step()
+            if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")):
+                self.scheduler.step()

             train_loss = train_loss + loss.item()
             train_loss_sup = train_loss_sup + loss_sup.item()
diff --git a/pymic/net_run_wsl/wsl_gatedcrf.py b/pymic/net_run_wsl/wsl_gatedcrf.py
index af6c562..d728328 100644
--- a/pymic/net_run_wsl/wsl_gatedcrf.py
+++ b/pymic/net_run_wsl/wsl_gatedcrf.py
@@ -86,7 +86,8 @@ def training(self):

             loss.backward()
             self.optimizer.step()
-            self.scheduler.step()
+            if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")):
+                self.scheduler.step()

             train_loss = train_loss + loss.item()
             train_loss_sup = train_loss_sup + loss_sup.item()
diff --git a/pymic/net_run_wsl/wsl_main.py b/pymic/net_run_wsl/wsl_main.py
index 595aa3e..916e1d8 100644
--- a/pymic/net_run_wsl/wsl_main.py
+++ b/pymic/net_run_wsl/wsl_main.py
@@ -5,19 +5,19 @@
 import os
 import sys
 from pymic.util.parse_config import *
-from pymic.net_run_wsl.wsl_em import WSL_EntropyMinimization
-from pymic.net_run_wsl.wsl_gatedcrf import WSL_GatedCRF
-from pymic.net_run_wsl.wsl_mumford_shah import WSL_MumfordShah
-from pymic.net_run_wsl.wsl_tv import WSL_TotalVariation
-from pymic.net_run_wsl.wsl_ustm import WSL_USTM
-from pymic.net_run_wsl.wsl_dmpls import WSL_DMPLS
+from pymic.net_run_wsl.wsl_em import WSLEntropyMinimization
+from pymic.net_run_wsl.wsl_gatedcrf import WSLGatedCRF
+from pymic.net_run_wsl.wsl_mumford_shah import WSLMumfordShah
+from pymic.net_run_wsl.wsl_tv import WSLTotalVariation
+from pymic.net_run_wsl.wsl_ustm import WSLUSTM
+from pymic.net_run_wsl.wsl_dmpls import WSLDMPLS

-WSLMethodDict = {'EntropyMinimization': WSL_EntropyMinimization,
-    'GatedCRF': WSL_GatedCRF,
-    'MumfordShah': WSL_MumfordShah,
-    'TotalVariation': WSL_TotalVariation,
-    'USTM': WSL_USTM,
-    'DMPLS': WSL_DMPLS}
+WSLMethodDict = {'EntropyMinimization': WSLEntropyMinimization,
+    'GatedCRF': WSLGatedCRF,
+    'MumfordShah': WSLMumfordShah,
+    'TotalVariation': WSLTotalVariation,
+    'USTM': WSLUSTM,
+    'DMPLS': WSLDMPLS}

 def main():
     if(len(sys.argv) < 3):
diff --git a/pymic/net_run_wsl/wsl_tv.py b/pymic/net_run_wsl/wsl_tv.py
index 9492150..fde1c10 100644
--- a/pymic/net_run_wsl/wsl_tv.py
+++ b/pymic/net_run_wsl/wsl_tv.py
@@ -9,10 +9,11 @@
 from pymic.loss.seg.ssl import TotalVariationLoss
 from pymic.net_run_wsl.wsl_abstract import WSLSegAgent
 from pymic.util.ramps import sigmoid_rampup
+from pymic.util.general import keyword_match

 class WSLTotalVariation(WSLSegAgent):
     """
-    Training and testing agent for semi-supervised segmentation
+    Weakly supervised segmentation with Total Variation Regularization.
     """
     def __init__(self, config, stage = 'train'):
         super(WSLTotalVariation, self).__init__(config, stage)
@@ -59,7 +60,8 @@ def training(self):
             # if (self.config['training']['use'])
             loss.backward()
             self.optimizer.step()
-            self.scheduler.step()
+            if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")):
+                self.scheduler.step()

             train_loss = train_loss + loss.item()
             train_loss_sup = train_loss_sup + loss_sup.item()
diff --git a/pymic/net_run_wsl/wsl_ustm.py b/pymic/net_run_wsl/wsl_ustm.py
index 32e79df..dd556e7 100644
--- a/pymic/net_run_wsl/wsl_ustm.py
+++ b/pymic/net_run_wsl/wsl_ustm.py
@@ -11,6 +11,7 @@
 from pymic.net.net_dict_seg import SegNetDict
 from pymic.net_run_wsl.wsl_abstract import WSLSegAgent
 from pymic.util.ramps import sigmoid_rampup
+from pymic.util.general import keyword_match

 class WSLUSTM(WSLSegAgent):
     """
@@ -108,7 +109,8 @@ def training(self):

             loss.backward()
             self.optimizer.step()
-            self.scheduler.step()
+            if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")):
+                self.scheduler.step()

             # update EMA
             alpha = wsl_cfg.get('ema_decay', 0.99)
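For illustration (not part of the patch series): after the renames, method selection is unchanged; the entry point looks the configured method name up in WSLMethodDict. A hypothetical sketch, assuming a parsed config dict:

    from pymic.net_run_wsl.wsl_main import WSLMethodDict

    method = 'DMPLS'                                  # value of wsl_method in the config
    agent  = WSLMethodDict[method](config, 'train')
    agent.run()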
From dd2f79781b598ed95f54a3cef851ca034164463f Mon Sep 17 00:00:00 2001
From: taigw
Date: Thu, 11 Aug 2022 15:22:42 +0800
Subject: [PATCH 16/26] Update wsl_gatedcrf.py

---
 pymic/net_run_wsl/wsl_gatedcrf.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/pymic/net_run_wsl/wsl_gatedcrf.py b/pymic/net_run_wsl/wsl_gatedcrf.py
index d728328..2ae8318 100644
--- a/pymic/net_run_wsl/wsl_gatedcrf.py
+++ b/pymic/net_run_wsl/wsl_gatedcrf.py
@@ -9,10 +9,15 @@
 from pymic.loss.seg.gatedcrf import ModelLossSemsegGatedCRF
 from pymic.net_run_wsl.wsl_abstract import WSLSegAgent
 from pymic.util.ramps import sigmoid_rampup
+from pymic.util.general import keyword_match
 
 class WSLGatedCRF(WSLSegAgent):
     """
-    Training and testing agent for semi-supervised segmentation
+    Implementation of the Gated CRF Loss for Weakly Supervised Semantic Image Segmentation.
+    Anton Obukhov, Stamatios Georgoulis, Dengxin Dai, Luc Van Gool:
+    Gated CRF Loss for Weakly Supervised Semantic Image Segmentation.
+    CoRR, abs/1906.04651, 2019
+    http://arxiv.org/abs/1906.04651
     """
     def __init__(self, config, stage = 'train'):
         super(WSLGatedCRF, self).__init__(config, stage)

From 41152e2fa85f6a7e26e0d332a0335acdc6abba16 Mon Sep 17 00:00:00 2001
From: taigw
Date: Thu, 11 Aug 2022 16:03:42 +0800
Subject: [PATCH 17/26] update mumford shah method

---
 pymic/loss/seg/mumford_shah.py        | 27 ++-------------------------
 pymic/net_run_wsl/wsl_mumford_shah.py |  9 +++++++--
 2 files changed, 9 insertions(+), 27 deletions(-)

diff --git a/pymic/loss/seg/mumford_shah.py b/pymic/loss/seg/mumford_shah.py
index eeaa250..f167b71 100644
--- a/pymic/loss/seg/mumford_shah.py
+++ b/pymic/loss/seg/mumford_shah.py
@@ -4,37 +4,14 @@
 import torch
 import torch.nn as nn
 
-class DiceLoss(nn.Module):
-    def __init__(self, params = None):
-        super(DiceLoss, self).__init__()
-        if(params is None):
-            self.softmax = True
-        else:
-            self.softmax = params.get('loss_softmax', True)
-
-    def forward(self, loss_input_dict):
-        predict = loss_input_dict['prediction']
-        soft_y  = loss_input_dict['ground_truth']
-
-        if(isinstance(predict, (list, tuple))):
-            predict = predict[0]
-        if(self.softmax):
-            predict = nn.Softmax(dim = 1)(predict)
-        predict = reshape_tensor_to_2D(predict)
-        soft_y  = reshape_tensor_to_2D(soft_y)
-        dice_score = get_classwise_dice(predict, soft_y)
-        dice_loss  = 1.0 - dice_score.mean()
-        return dice_loss
-
 class MumfordShahLoss(nn.Module):
     """
     Implementation of Mumford Shah Loss in this paper:
-    Boah Kim and Jong Chul Ye, Mumford–Shah Loss Functional
+    Boah Kim and Jong Chul Ye: Mumford–Shah Loss Functional
     for Image Segmentation With Deep Learning. IEEE TIP, 2019.
     The original implementation is available at:
     https://github.com/jongcye/CNN_MumfordShah_Loss
-
-    currently only 2D version is supported.
+    Currently only 2D version is supported.
     """
     def __init__(self, params = None):
         super(MumfordShahLoss, self).__init__()
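MumfordShahLoss treats the softmax output as soft region indicators: each class penalizes the intensity variance of the pixels it claims, and a smoothness term penalizes contour length. A rough sketch of the data term only, under those assumptions; this is an illustration, not the repository implementation:

    import torch

    def mumford_shah_data_term(image, prob):
        # image: [N, 1, H, W]; prob: [N, C, H, W] softmax output.
        # For each class c, penalize the squared deviation of the image from
        # the class's soft mean intensity, weighted by prob[:, c].
        loss = 0.0
        for c in range(prob.shape[1]):
            p_c = prob[:, c:c+1]
            mean_c = (image * p_c).sum(dim=(2, 3), keepdim=True) / \
                     (p_c.sum(dim=(2, 3), keepdim=True) + 1e-5)
            loss = loss + ((image - mean_c) ** 2 * p_c).mean()
        return loss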
diff --git a/pymic/net_run_wsl/wsl_mumford_shah.py b/pymic/net_run_wsl/wsl_mumford_shah.py
index f642e59..095a0f6 100644
--- a/pymic/net_run_wsl/wsl_mumford_shah.py
+++ b/pymic/net_run_wsl/wsl_mumford_shah.py
@@ -9,10 +9,14 @@
 from pymic.loss.seg.mumford_shah import MumfordShahLoss
 from pymic.net_run_wsl.wsl_abstract import WSLSegAgent
 from pymic.util.ramps import sigmoid_rampup
+from pymic.util.general import keyword_match
 
 class WSLMumfordShah(WSLSegAgent):
     """
-    Training and testing agent for semi-supervised segmentation
+    Weakly supervised learning with Mumford Shah Loss according to this paper:
+    Boah Kim and Jong Chul Ye: Mumford–Shah Loss Functional
+    for Image Segmentation With Deep Learning. IEEE TIP, 2019.
+    https://doi.org/10.1109/TIP.2019.2941265
     """
     def __init__(self, config, stage = 'train'):
         super(WSLMumfordShah, self).__init__(config, stage)
@@ -61,7 +65,8 @@ def training(self):
             # if (self.config['training']['use'])
             loss.backward()
             self.optimizer.step()
-            self.scheduler.step()
+            if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")):
+                self.scheduler.step()
 
             train_loss = train_loss + loss.item()
             train_loss_sup = train_loss_sup + loss_sup.item()

From cb61d458b5db53ae0f789fa767105d85bbeef71a Mon Sep 17 00:00:00 2001
From: taigw
Date: Sat, 13 Aug 2022 09:23:19 +0800
Subject: [PATCH 18/26] Update wsl_ustm.py

add reference
---
 pymic/net_run_wsl/wsl_ustm.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/pymic/net_run_wsl/wsl_ustm.py b/pymic/net_run_wsl/wsl_ustm.py
index dd556e7..6083069 100644
--- a/pymic/net_run_wsl/wsl_ustm.py
+++ b/pymic/net_run_wsl/wsl_ustm.py
@@ -15,7 +15,12 @@
 
 class WSLUSTM(WSLSegAgent):
     """
-    Training and testing agent for semi-supervised segmentation
+    USTM for scribble-supervised segmentation according to the following paper:
+    Xiaoming Liu, Quan Yuan, Yaozong Gao, Helei He, Shuo Wang, Xiao Tang,
+    Jinshan Tang, Dinggang Shen:
+    Weakly Supervised Segmentation of COVID19 Infection with Scribble Annotation on CT Images.
+    Pattern Recognition, 2022.
+    https://doi.org/10.1016/j.patcog.2021.108341
     """
     def __init__(self, config, stage = 'train'):
         super(WSLUSTM, self).__init__(config, stage)

From 251e4d1a42083907b0d52a80b36f04a5cd9a137f Mon Sep 17 00:00:00 2001
From: taigw
Date: Wed, 17 Aug 2022 16:07:54 +0800
Subject: [PATCH 19/26] update SSL, WSL and NLL

update rampup and lr scheduler.step
---
 pymic/loss/loss_dict_seg.py           |   4 +-
 pymic/loss/seg/ce.py                  |  23 ++--
 pymic/loss/seg/slsr.py                |   9 +-
 pymic/net_run/agent_seg.py            |   9 +-
 pymic/net_run/get_optimizer.py        |   2 +
 pymic/net_run_nll/nll_cl.py           |  52 ++++----
 pymic/net_run_nll/nll_co_teaching.py  | 109 ++++++----------
 pymic/net_run_nll/nll_main.py         |  37 ++++++
 pymic/net_run_nll/nll_trinet.py       | 178 ++++++++++++++++++++++++++
 pymic/net_run_ssl/ssl_abstract.py     |   4 -
 pymic/net_run_ssl/ssl_cct.py          |  22 ++--
 pymic/net_run_ssl/ssl_cps.py          |  83 +++++------
 pymic/net_run_ssl/ssl_em.py           |  20 +--
 pymic/net_run_ssl/ssl_mt.py           |  20 +--
 pymic/net_run_ssl/ssl_uamt.py         |  22 ++--
 pymic/net_run_ssl/ssl_urpc.py         |  20 ++-
 pymic/net_run_wsl/wsl_dmpls.py        |  19 ++-
 pymic/net_run_wsl/wsl_em.py           |  21 ++-
 pymic/net_run_wsl/wsl_gatedcrf.py     |  19 ++-
 pymic/net_run_wsl/wsl_mumford_shah.py |  19 ++-
 pymic/net_run_wsl/wsl_tv.py           |  18 +--
 pymic/net_run_wsl/wsl_ustm.py         |  21 ++-
 pymic/util/ramps.py                   |  27 ++--
 23 files changed, 457 insertions(+), 301 deletions(-)
 create mode 100644 pymic/net_run_nll/nll_main.py
 create mode 100644 pymic/net_run_nll/nll_trinet.py
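The central refactor of this patch replaces the inline sigmoid_rampup arithmetic at every call site with a single get_rampup_ratio(iteration, start, end, mode) helper in pymic/util/ramps.py. Judging from the call sites below, it maps the current iteration to a ratio in [0, 1] between rampup_start and rampup_end; the following is a hypothetical reconstruction, using the exp(-5(1-t)^2) sigmoid ramp that sigmoid_rampup implemented:

    import numpy as np

    def get_rampup_ratio(i, start, end, mode="linear"):
        # hypothetical sketch of pymic.util.ramps.get_rampup_ratio,
        # reconstructed from its call sites; not the actual source
        assert end > start
        i = np.clip(i, start, end)
        t = (i - start) / (end - start)        # linear progress in [0, 1]
        if mode == "sigmoid":
            t = np.exp(-5.0 * (1.0 - t) ** 2)  # Laine & Aila style ramp-up
        return t

At the call sites the regularization weight then becomes regular_w = cfg.get('regularize_w', 0.1) * get_rampup_ratio(it, rampup_start, rampup_end, "sigmoid").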
diff --git a/pymic/loss/loss_dict_seg.py b/pymic/loss/loss_dict_seg.py
index 929ec43..a8a53ad 100644
--- a/pymic/loss/loss_dict_seg.py
+++ b/pymic/loss/loss_dict_seg.py
@@ -1,14 +1,14 @@
 # -*- coding: utf-8 -*-
 from __future__ import print_function, division
 import torch.nn as nn
-from pymic.loss.seg.ce import CrossEntropyLoss, GeneralizedCrossEntropyLoss
+from pymic.loss.seg.ce import CrossEntropyLoss, GeneralizedCELoss
 from pymic.loss.seg.dice import DiceLoss, FocalDiceLoss, NoiseRobustDiceLoss
 from pymic.loss.seg.slsr import SLSRLoss
 from pymic.loss.seg.exp_log import ExpLogLoss
 from pymic.loss.seg.mse import MSELoss, MAELoss
 
 SegLossDict = {'CrossEntropyLoss': CrossEntropyLoss,
-               'GeneralizedCrossEntropyLoss': GeneralizedCrossEntropyLoss,
+               'GeneralizedCELoss': GeneralizedCELoss,
                'SLSRLoss': SLSRLoss,
                'DiceLoss': DiceLoss,
                'FocalDiceLoss': FocalDiceLoss,
diff --git a/pymic/loss/seg/ce.py b/pymic/loss/seg/ce.py
index dadeba7..da2bf14 100644
--- a/pymic/loss/seg/ce.py
+++ b/pymic/loss/seg/ce.py
@@ -59,34 +59,36 @@ def forward(self, loss_input_dict):
             ce = torch.mean(ce)
         return ce
 
-class GeneralizedCrossEntropyLoss(nn.Module):
+class GeneralizedCELoss(nn.Module):
     """
     Generalized cross entropy loss to deal with noisy labels.
     Z. Zhang et al. Generalized Cross Entropy Loss for Training Deep Neural
     Networks with Noisy Labels, NeurIPS 2018.
     """
     def __init__(self, params):
+        """
+        q: in (0, 1]; approaches CE as q -> 0, and becomes MAE when q = 1
+        """
-        super(GeneralizedCrossEntropyLoss, self).__init__()
-        self.enable_pix_weight = params['GeneralizedCrossEntropyLoss_Enable_Pixel_Weight'.lower()]
-        self.enable_cls_weight = params['GeneralizedCrossEntropyLoss_Enable_Class_Weight'.lower()]
-        self.q = params['GeneralizedCrossEntropyLoss_q'.lower()]
+        super(GeneralizedCELoss, self).__init__()
+        self.enable_pix_weight = params.get('GeneralizedCELoss_Enable_Pixel_Weight', False)
+        self.enable_cls_weight = params.get('GeneralizedCELoss_Enable_Class_Weight', False)
+        self.q       = params.get('GeneralizedCELoss_q', 0.5)
+        self.softmax = params.get('loss_softmax', True)
 
     def forward(self, loss_input_dict):
         predict = loss_input_dict['prediction']
-        soft_y = loss_input_dict['ground_truth']
-        pix_w  = loss_input_dict['pixel_weight']
-        cls_w  = loss_input_dict['class_weight']
-        softmax = loss_input_dict['softmax']
+        soft_y  = loss_input_dict['ground_truth']
 
         if(isinstance(predict, (list, tuple))):
             predict = predict[0]
-        if(softmax):
+        if(self.softmax):
             predict = nn.Softmax(dim = 1)(predict)
         predict = reshape_tensor_to_2D(predict)
         soft_y  = reshape_tensor_to_2D(soft_y)
 
         gce = (1.0 - torch.pow(predict, self.q)) / self.q * soft_y
         if(self.enable_cls_weight):
+            cls_w = loss_input_dict.get('class_weight', None)
             if(cls_w is None):
                 raise ValueError("Class weight is enabled but not defined")
             gce = torch.sum(gce * cls_w, dim = 1)
@@ -94,6 +96,7 @@ def forward(self, loss_input_dict):
             gce = torch.sum(gce, dim = 1)
 
         if(self.enable_pix_weight):
+            pix_w = loss_input_dict.get('pixel_weight', None)
             if(pix_w is None):
                 raise ValueError("Pixel weight is enabled but not defined")
             pix_w = reshape_tensor_to_2D(pix_w)
diff --git a/pymic/loss/seg/slsr.py b/pymic/loss/seg/slsr.py
index 6ad60b3..706d2fc 100644
--- a/pymic/loss/seg/slsr.py
+++ b/pymic/loss/seg/slsr.py
@@ -2,8 +2,10 @@
 """
 Spatial Label Smoothing Regularization (SLSR) loss for learning from
 noisy annotations according to the following paper:
-    Minqing Zhang, Jiantao Gao et al., Characterizing Label Errors:
-    Confident Learning for Noisy-Labeled Image Segmentation, MICCAI 2020.
+    Minqing Zhang, Jiantao Gao et al.:
+    Characterizing Label Errors: Confident Learning for Noisy-Labeled Image
+    Segmentation, MICCAI 2020.
+ https://link.springer.com/chapter/10.1007/978-3-030-59710-8_70 """ from __future__ import print_function, division @@ -17,7 +19,7 @@ def __init__(self, params): if(params is None): params = {} self.softmax = params.get('loss_softmax', True) - self.epsilon = params.get('slsrloss_softmax', 0.25) + self.epsilon = params.get('slsrloss_epsilon', 0.25) def forward(self, loss_input_dict): predict = loss_input_dict['prediction'] @@ -35,7 +37,6 @@ def forward(self, loss_input_dict): soft_y = reshape_tensor_to_2D(soft_y) if(pix_w is not None): pix_w = reshape_tensor_to_2D(pix_w > 0).float() - # smooth labels for pixels in the unconfident mask smooth_y = (soft_y - 0.5) * (0.5 - self.epsilon) / 0.5 + 0.5 smooth_y = pix_w * smooth_y + (1 - pix_w) * soft_y diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index 38c75ae..d8e74a2 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -10,6 +10,7 @@ import numpy as np import torch.nn as nn import torch.optim as optim +from torch.optim import lr_scheduler import torch.nn.functional as F from datetime import datetime from tensorboardX import SummaryWriter @@ -27,7 +28,6 @@ from pymic.transform.trans_dict import TransformDict from pymic.util.post_process import PostProcessDict from pymic.util.image_process import convert_label -from pymic.util.general import keyword_match class SegmentationAgent(NetRunAgent): def __init__(self, config, stage = 'train'): @@ -164,7 +164,7 @@ def set_postprocessor(self, postprocessor): def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] - train_loss = 0 + train_loss = 0 train_dice_list = [] self.net.train() for it in range(iter_valid): @@ -201,7 +201,8 @@ def training(self): loss = self.get_loss_value(data, outputs, labels_prob) loss.backward() self.optimizer.step() - if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")): + if(self.scheduler is not None and \ + not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): self.scheduler.step() train_loss = train_loss + loss.item() @@ -258,7 +259,7 @@ def validation(self): valid_cls_dice = np.asarray(valid_dice_list).mean(axis = 0) valid_avg_dice = valid_cls_dice.mean() - if(keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")): + if(isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): self.scheduler.step(valid_avg_dice) valid_scalers = {'loss': valid_avg_loss, 'avg_dice': valid_avg_dice,\ diff --git a/pymic/net_run/get_optimizer.py b/pymic/net_run/get_optimizer.py index e475286..c4504de 100644 --- a/pymic/net_run/get_optimizer.py +++ b/pymic/net_run/get_optimizer.py @@ -38,6 +38,8 @@ def get_optimizer(name, net_params, optim_params): def get_lr_scheduler(optimizer, sched_params): name = sched_params["lr_scheduler"] + if(name is None): + return None lr_gamma = sched_params["lr_gamma"] if(keyword_match(name, "ReduceLROnPlateau")): patience_it = sched_params["ReduceLROnPlateau_patience".lower()] diff --git a/pymic/net_run_nll/nll_cl.py b/pymic/net_run_nll/nll_cl.py index 8792ccd..8173471 100644 --- a/pymic/net_run_nll/nll_cl.py +++ b/pymic/net_run_nll/nll_cl.py @@ -14,6 +14,7 @@ import sys import torch import numpy as np +import pandas as pd import torch.nn as nn import torchvision.transforms as transforms from PIL import Image @@ -45,9 +46,9 @@ def get_confident_map(gt, pred, CL_type = 'both'): noise = cleanlab.pruning.get_noise_indices(gt, prob, prune_method=CL_type, n_jobs=1) return noise -class 
NLLConfidentLeran(SegmentationAgent):
+class NLLConfidentLearn(SegmentationAgent):
     def __init__(self, config, stage = 'test'):
-        super(NLLConfidentLeran, self).__init__(config, stage)
+        super(NLLConfidentLearn, self).__init__(config, stage)
 
     def infer_with_cl(self):
         device_ids = self.config['testing']['gpus']
@@ -93,16 +94,6 @@ def test_time_dropout(m):
                 filename_list.append(names)
                 images = images.to(device)
 
-                # for debug
-                # for i in range(images.shape[0]):
-                #     image_i = images[i][0]
-                #     label_i = images[i][0]
-                #     image_name = "temp/{0:}_image.nii.gz".format(names[0])
-                #     label_name = "temp/{0:}_label.nii.gz".format(names[0])
-                #     save_nd_array_as_image(image_i, image_name, reference_name = None)
-                #     save_nd_array_as_image(label_i, label_name, reference_name = None)
-                # continue
-
                 pred = self.inferer.run(self.net, images)
                 # convert tensor to numpy
                 if(isinstance(pred, (tuple, list))):
@@ -142,15 +133,10 @@ def test_time_dropout(m):
             dst_path = os.path.join(save_dir, filename)
             conf_map.save(dst_path)
 
-    def run(self):
-        self.create_dataset()
-        self.create_network()
-        self.infer_with_cl()
-
-def main():
+def get_confidence_map():
     if(len(sys.argv) < 2):
         print('Number of arguments should be 2. e.g.')
-        print('    python cl.py config.cfg')
+        print('    python nll_cl.py config.cfg')
         exit()
     cfg_file = str(sys.argv[1])
     config   = parse_config(cfg_file)
@@ -172,17 +158,35 @@ def main():
         transform_list.append(one_transform)
     data_transform = transforms.Compose(transform_list)
     print('transform list', transform_list)
-    csv_file = config['dataset']['train_csv']
+    csv_file  = config['dataset']['train_csv']
+    modal_num = config['dataset'].get('modal_num', 1)
     dataset  = NiftyDataset(root_dir  = config['dataset']['root_dir'],
                             csv_file  = csv_file,
-                            modal_num = config['dataset']['modal_num'],
+                            modal_num = modal_num,
                             with_label= True,
                             transform = data_transform )
-    agent = NLLConfidentLeran(config, 'test')
+    agent = NLLConfidentLearn(config, 'test')
    agent.set_datasets(None, None, dataset)
     agent.transform_list = transform_list
-    agent.run()
+    agent.create_dataset()
+    agent.create_network()
+    agent.infer_with_cl()
+
+    # create training csv for confidence learning
+    df_train = pd.read_csv(csv_file)
+    pixel_weight = []
+    weight_dir   = config['testing']['output_dir'] + "_conf"
+    for i in range(len(df_train["label"])):
+        lab_name = df_train["label"][i].split('/')[-1]
+        weight_name = "../" + weight_dir + '/' + lab_name
+        pixel_weight.append(weight_name)
+    train_cl_dict = {"image": df_train["image"],
+                     "pixel_weight": pixel_weight,
+                     "label": df_train["label"]}
+    train_cl_csv = csv_file.replace(".csv", "_cl.csv")
+    df_cl = pd.DataFrame.from_dict(train_cl_dict)
+    df_cl.to_csv(train_cl_csv, index = False)
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    get_confidence_map()
\ No newline at end of file
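get_confidence_map() writes one confidence map per training label and a <train_csv>_cl.csv whose pixel_weight column points at those maps; SLSRLoss (updated above) then smooths the labels only on the pixels flagged as unconfident. With the default slsrloss_epsilon of 0.25, the smoothing formula from slsr.py works out as follows:

    # worked example of smooth_y = (soft_y - 0.5) * (0.5 - eps) / 0.5 + 0.5
    eps = 0.25
    for soft_y in (1.0, 0.0):
        smooth_y = (soft_y - 0.5) * (0.5 - eps) / 0.5 + 0.5
        print(soft_y, "->", smooth_y)   # 1.0 -> 0.75, 0.0 -> 0.25

so a hard one-hot label on a suspect pixel is pulled to (0.75, 0.25), while confident pixels keep their original labels.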
diff --git a/pymic/net_run_nll/nll_co_teaching.py b/pymic/net_run_nll/nll_co_teaching.py
index e1392eb..bcaec4e 100644
--- a/pymic/net_run_nll/nll_co_teaching.py
+++ b/pymic/net_run_nll/nll_co_teaching.py
@@ -11,23 +11,37 @@
 """
 from __future__ import print_function, division
 import logging
+import os
+import sys
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.optim as optim
+from torch.optim import lr_scheduler
 from pymic.loss.seg.util import get_soft_label
 from pymic.loss.seg.util import reshape_prediction_and_ground_truth
 from pymic.loss.seg.util import get_classwise_dice
 from pymic.loss.seg.util import reshape_tensor_to_2D
-from pymic.util.ramps import sigmoid_rampup
-from pymic.net_run.get_optimizer import get_optimiser
 from pymic.net_run.agent_seg import SegmentationAgent
 from pymic.net.net_dict_seg import SegNetDict
-
-import logging
-import os
-import sys
 from pymic.util.parse_config import *
+from pymic.util.ramps import get_rampup_ratio
+
+class BiNet(nn.Module):
+    def __init__(self, params):
+        super(BiNet, self).__init__()
+        net_name  = params['net_type']
+        self.net1 = SegNetDict[net_name](params)
+        self.net2 = SegNetDict[net_name](params)
+
+    def forward(self, x):
+        out1 = self.net1(x)
+        out2 = self.net2(x)
+
+        if(self.training):
+            return out1, out2
+        else:
+            # average the two networks' predictions at inference time
+            return (out1 + out2) / 2
 
 class NLLCoTeaching(SegmentationAgent):
     """
@@ -37,48 +51,27 @@ class NLLCoTeaching(SegmentationAgent):
     """
     def __init__(self, config, stage = 'train'):
         super(NLLCoTeaching, self).__init__(config, stage)
-        self.net2 = None
-        self.optimizer2 = None
-        self.scheduler2 = None
         loss_type = config['training']["loss_type"]
         if(loss_type != "CrossEntropyLoss"):
             logging.warn("only CrossEntropyLoss supported for" +
                " coteaching, the specified loss {0:} is ignored".format(loss_type))
 
     def create_network(self):
-        super(NLLCoTeaching, self).create_network()
-        if(self.net2 is None):
-            net_name = self.config['network']['net_type']
-            if(net_name not in SegNetDict):
-                raise ValueError("Undefined network {0:}".format(net_name))
-            self.net2 = SegNetDict[net_name](self.config['network'])
+        if(self.net is None):
+            self.net = BiNet(self.config['network'])
         if(self.tensor_type == 'float'):
-            self.net2.float()
+            self.net.float()
         else:
-            self.net2.double()
-
-    def train_valid(self):
-        # create optimizer for the second network
-        if(self.optimizer2 is None):
-            self.optimizer2 = get_optimiser(self.config['training']['optimizer'],
-                self.net2.parameters(),
-                self.config['training'])
-            last_iter = -1
-            # if(self.checkpoint is not None):
-            #     self.optimizer2.load_state_dict(self.checkpoint['optimizer_state_dict'])
-            #     last_iter = self.checkpoint['iteration'] - 1
-        if(self.scheduler2 is None):
-            self.scheduler2 = optim.lr_scheduler.MultiStepLR(self.optimizer2,
-                self.config['training']['lr_milestones'],
-                self.config['training']['lr_gamma'],
-                last_epoch = last_iter)
-        super(NLLCoTeaching, self).train_valid()
+            self.net.double()
 
     def training(self):
         class_num = self.config['network']['class_num']
         iter_valid = self.config['training']['iter_valid']
-        select_ratio = self.config['training']['co_teaching_select_ratio']
-        rampup_length = self.config['training']['co_teaching_rampup_length']
+        nll_cfg = self.config['noisy_label_learning']
+        select_ratio = nll_cfg['co_teaching_select_ratio']
+        iter_max     = self.config['training']['iter_max']
+        rampup_start = nll_cfg.get('rampup_start', 0)
+        rampup_end   = nll_cfg.get('rampup_end', iter_max)
 
         train_loss_no_select1 = 0
         train_loss_no_select2 = 0
@@ -86,8 +79,6 @@ def training(self):
         train_loss2 = 0
         train_dice_list = []
         self.net.train()
-        self.net2.train()
-        self.net2.to(self.device)
         for it in range(iter_valid):
             try:
                 data = next(self.trainIter)
@@ -102,11 +93,9 @@ def training(self):
 
             # zero the parameter gradients
             self.optimizer.zero_grad()
-            self.optimizer2.zero_grad()
 
             # forward + backward + optimize
-            outputs1 = self.net(inputs)
-            outputs2 = self.net2(inputs)
+            outputs1, outputs2 = self.net(inputs)
 
             prob1 = nn.Softmax(dim = 1)(outputs1)
             prob2 = nn.Softmax(dim = 1)(outputs2)
@@ -122,8 +111,9 @@ def training(self):
             loss2 = torch.sum(loss2, dim = 1) # shape is [N]
             ind_2_sorted = torch.argsort(loss2)
 
-            forget_ratio = (1 - select_ratio) * self.glob_it / rampup_length
-            remb_ratio 
= max(select_ratio, 1 - forget_ratio) + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") + forget_ratio = (1 - select_ratio) * rampup_ratio + remb_ratio = 1 - forget_ratio num_remb = int(remb_ratio * len(loss1)) ind_1_update = ind_1_sorted[:num_remb] @@ -134,22 +124,17 @@ def training(self): loss = loss1_select.mean() + loss2_select.mean() - # if (self.config['training']['use']) loss.backward() self.optimizer.step() - self.scheduler.step() - self.optimizer2.step() - self.scheduler2.step() + if(self.scheduler is not None and \ + not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): + self.scheduler.step() train_loss_no_select1 = train_loss_no_select1 + loss1.mean().item() train_loss_no_select2 = train_loss_no_select2 + loss2.mean().item() train_loss1 = train_loss1 + loss1_select.mean().item() train_loss2 = train_loss2 + loss2_select.mean().item() - # get dice evaluation for each class in annotated images - # if(isinstance(outputs1, tuple) or isinstance(outputs1, list)): - # outputs1 = outputs1[0] - outputs1_argmax = torch.argmax(outputs1, dim = 1, keepdim = True) soft_out1 = get_soft_label(outputs1_argmax, class_num, self.tensor_type) soft_out1, labels_prob = reshape_prediction_and_ground_truth(soft_out1, labels_prob) @@ -169,7 +154,7 @@ def training(self): 'select_ratio':remb_ratio, 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice} return train_scalers - def write_scalars(self, train_scalars, valid_scalars, glob_it): + def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): loss_scalar ={'train':train_scalars['loss'], 'valid':valid_scalars['loss']} loss_no_select_scalar = {'net1':train_scalars['loss_no_select1'], @@ -179,6 +164,7 @@ def write_scalars(self, train_scalars, valid_scalars, glob_it): self.summ_writer.add_scalars('loss', loss_scalar, glob_it) self.summ_writer.add_scalars('loss_no_select', loss_no_select_scalar, glob_it) self.summ_writer.add_scalars('select_ratio', {'select_ratio':train_scalars['select_ratio']}, glob_it) + self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it) self.summ_writer.add_scalars('dice', dice_scalar, glob_it) class_num = self.config['network']['class_num'] for c in range(class_num): @@ -192,22 +178,3 @@ def write_scalars(self, train_scalars, valid_scalars, glob_it): logging.info('valid loss {0:.4f}, avg dice {1:.4f} '.format( valid_scalars['loss'], valid_scalars['avg_dice']) + "[" + \ ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") - -if __name__ == "__main__": - if(len(sys.argv) < 3): - print('Number of arguments should be 3. 
e.g.')
-        print('   pymic_ssl train config.cfg')
-        exit()
-    stage    = str(sys.argv[1])
-    cfg_file = str(sys.argv[2])
-    config   = parse_config(cfg_file)
-    config   = synchronize_config(config)
-    log_dir  = config['training']['ckpt_save_dir']
-    if(not os.path.exists(log_dir)):
-        os.mkdir(log_dir)
-    logging.basicConfig(filename=log_dir+"/log.txt", level=logging.INFO,
-                        format='%(message)s')
-    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
-    logging_config(config)
-    agent = NLLCoTeaching(config, stage)
-    agent.run()
\ No newline at end of file
diff --git a/pymic/net_run_nll/nll_main.py b/pymic/net_run_nll/nll_main.py
new file mode 100644
index 0000000..d1ae7a1
--- /dev/null
+++ b/pymic/net_run_nll/nll_main.py
@@ -0,0 +1,37 @@
+
+# -*- coding: utf-8 -*-
+from __future__ import print_function, division
+import logging
+import os
+import sys
+from pymic.util.parse_config import *
+from pymic.net_run_nll.nll_co_teaching import NLLCoTeaching
+from pymic.net_run_nll.nll_trinet import NLLTriNet
+
+NLLMethodDict = {'CoTeaching': NLLCoTeaching,
+                 "TriNet": NLLTriNet}
+
+def main():
+    if(len(sys.argv) < 3):
+        print('Number of arguments should be 3. e.g.')
+        print('   pymic_nll train config.cfg')
+        exit()
+    stage    = str(sys.argv[1])
+    cfg_file = str(sys.argv[2])
+    config   = parse_config(cfg_file)
+    config   = synchronize_config(config)
+    log_dir  = config['training']['ckpt_save_dir']
+    if(not os.path.exists(log_dir)):
+        os.mkdir(log_dir)
+    logging.basicConfig(filename=log_dir+"/log.txt", level=logging.INFO,
+                        format='%(message)s')
+    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
+    logging_config(config)
+    nll_method = config['noisy_label_learning']['nll_method']
+    agent = NLLMethodDict[nll_method](config, stage)
+    agent.run()
+
+if __name__ == "__main__":
+    main()
+
+    
\ No newline at end of file
diff --git a/pymic/net_run_nll/nll_trinet.py b/pymic/net_run_nll/nll_trinet.py
new file mode 100644
index 0000000..eb0ecdd
--- /dev/null
+++ b/pymic/net_run_nll/nll_trinet.py
@@ -0,0 +1,178 @@
+# -*- coding: utf-8 -*-
+"""
+Implementation of a tri-network extension of Co-teaching for learning
+from noisy labels in segmentation tasks: each network is taught with
+the confident (small-loss) pixels agreed on by the other two networks.
+Co-teaching reference:
+    Bo Han et al., Co-teaching: Robust Training of Deep Neural Networks
+    with Extremely Noisy Labels, NeurIPS, 2018
+The authors' original implementation of Co-teaching is at:
+https://github.com/bhanML/Co-teaching
+"""
+from __future__ import print_function, division
+import logging
+import os
+import sys
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.optim import lr_scheduler
+from pymic.loss.seg.util import get_soft_label
+from pymic.loss.seg.util import reshape_prediction_and_ground_truth
+from pymic.loss.seg.util import get_classwise_dice
+from pymic.loss.seg.util import reshape_tensor_to_2D
+from pymic.net_run.agent_seg import SegmentationAgent
+from pymic.net.net_dict_seg import SegNetDict
+from pymic.util.parse_config import *
+from pymic.util.ramps import get_rampup_ratio
+
+
+
+class TriNet(nn.Module):
+    def __init__(self, params):
+        super(TriNet, self).__init__()
+        net_name  = params['net_type']
+        self.net1 = SegNetDict[net_name](params)
+        self.net2 = SegNetDict[net_name](params)
+        self.net3 = SegNetDict[net_name](params)
+
+    def forward(self, x):
+        out1 = self.net1(x)
+        out2 = self.net2(x)
+        out3 = self.net3(x)
+
+        if(self.training):
+            return out1, out2, out3
+        else:
+            return (out1 + out2 + out3) / 3
+
+class NLLTriNet(SegmentationAgent):
+    """
+    Tri-network variant of Co-teaching: Robust Training of Deep Neural Networks with 
Extremely + Noisy Labels + https://arxiv.org/abs/1804.06872 + """ + def __init__(self, config, stage = 'train'): + super(NLLTriNet, self).__init__(config, stage) + + def create_network(self): + if(self.net is None): + self.net = TriNet(self.config['network']) + if(self.tensor_type == 'float'): + self.net.float() + else: + self.net.double() + + def get_loss_and_confident_mask(self, pred, labels_prob, conf_ratio): + prob = nn.Softmax(dim = 1)(pred) + prob_2d = reshape_tensor_to_2D(prob) * 0.999 + 5e-4 + y_2d = reshape_tensor_to_2D(labels_prob) + + loss = - y_2d* torch.log(prob_2d) + loss = torch.sum(loss, dim = 1) # shape is [N] + threshold = torch.quantile(loss, conf_ratio) + mask = loss < threshold + return loss, mask + + def training(self): + class_num = self.config['network']['class_num'] + iter_valid = self.config['training']['iter_valid'] + nll_cfg = self.config['noisy_label_learning'] + iter_max = self.config['training']['iter_max'] + select_ratio = nll_cfg['trinet_select_ratio'] + rampup_start = nll_cfg.get('rampup_start', 0) + rampup_end = nll_cfg.get('rampup_end', iter_max) + + train_loss_no_select1 = 0 + train_loss_no_select2 = 0 + train_loss1, train_loss2, train_loss3 = 0, 0, 0 + train_dice_list = [] + self.net.train() + for it in range(iter_valid): + try: + data = next(self.trainIter) + except StopIteration: + self.trainIter = iter(self.train_loader) + data = next(self.trainIter) + + # get the inputs + inputs = self.convert_tensor_type(data['image']) + labels_prob = self.convert_tensor_type(data['label_prob']) + inputs, labels_prob = inputs.to(self.device), labels_prob.to(self.device) + + # zero the parameter gradients + self.optimizer.zero_grad() + + # forward + backward + optimize + outputs1, outputs2, outputs3 = self.net(inputs) + + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end) + forget_ratio = (1 - select_ratio) * rampup_ratio + remb_ratio = 1 - forget_ratio + + loss1, mask1 = self.get_loss_and_confident_mask(outputs1, labels_prob, remb_ratio) + loss2, mask2 = self.get_loss_and_confident_mask(outputs2, labels_prob, remb_ratio) + loss3, mask3 = self.get_loss_and_confident_mask(outputs3, labels_prob, remb_ratio) + mask12, mask13, mask23 = mask1 * mask2, mask1 * mask3, mask2 * mask3 + mask12, mask13, mask23 = mask12.detach(), mask13.detach(), mask23.detach() + + loss1_avg = torch.sum(loss1 * mask23) / mask23.sum() + loss2_avg = torch.sum(loss2 * mask13) / mask13.sum() + loss3_avg = torch.sum(loss3 * mask12) / mask12.sum() + loss = (loss1_avg + loss2_avg + loss3_avg) / 3 + + loss.backward() + self.optimizer.step() + if(self.scheduler is not None and \ + not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): + self.scheduler.step() + + train_loss_no_select1 = train_loss_no_select1 + loss1.mean().item() + train_loss_no_select2 = train_loss_no_select2 + loss2.mean().item() + train_loss1 = train_loss1 + loss1_avg.item() + train_loss2 = train_loss2 + loss2_avg.item() + + outputs1_argmax = torch.argmax(outputs1, dim = 1, keepdim = True) + soft_out1 = get_soft_label(outputs1_argmax, class_num, self.tensor_type) + soft_out1, labels_prob = reshape_prediction_and_ground_truth(soft_out1, labels_prob) + dice_list = get_classwise_dice(soft_out1, labels_prob).detach().cpu().numpy() + train_dice_list.append(dice_list) + train_avg_loss_no_select1 = train_loss_no_select1 / iter_valid + train_avg_loss_no_select2 = train_loss_no_select2 / iter_valid + train_avg_loss1 = train_loss1 / iter_valid + train_avg_loss2 = train_loss2 / iter_valid + train_cls_dice = 
np.asarray(train_dice_list).mean(axis = 0) + train_avg_dice = train_cls_dice.mean() + + train_scalers = {'loss': (train_avg_loss1 + train_avg_loss2) / 2, + 'loss1':train_avg_loss1, 'loss2': train_avg_loss2, + 'loss_no_select1':train_avg_loss_no_select1, + 'loss_no_select2':train_avg_loss_no_select2, + 'select_ratio':remb_ratio, 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice} + return train_scalers + + def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): + loss_scalar ={'train':train_scalars['loss'], + 'valid':valid_scalars['loss']} + loss_no_select_scalar = {'net1':train_scalars['loss_no_select1'], + 'net2':train_scalars['loss_no_select2']} + + dice_scalar ={'train':train_scalars['avg_dice'], 'valid':valid_scalars['avg_dice']} + self.summ_writer.add_scalars('loss', loss_scalar, glob_it) + self.summ_writer.add_scalars('loss_no_select', loss_no_select_scalar, glob_it) + self.summ_writer.add_scalars('select_ratio', {'select_ratio':train_scalars['select_ratio']}, glob_it) + self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it) + self.summ_writer.add_scalars('dice', dice_scalar, glob_it) + class_num = self.config['network']['class_num'] + for c in range(class_num): + cls_dice_scalar = {'train':train_scalars['class_dice'][c], \ + 'valid':valid_scalars['class_dice'][c]} + self.summ_writer.add_scalars('class_{0:}_dice'.format(c), cls_dice_scalar, glob_it) + + logging.info('train loss {0:.4f}, avg dice {1:.4f} '.format( + train_scalars['loss'], train_scalars['avg_dice']) + "[" + \ + ' '.join("{0:.4f}".format(x) for x in train_scalars['class_dice']) + "]") + logging.info('valid loss {0:.4f}, avg dice {1:.4f} '.format( + valid_scalars['loss'], valid_scalars['avg_dice']) + "[" + \ + ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") diff --git a/pymic/net_run_ssl/ssl_abstract.py b/pymic/net_run_ssl/ssl_abstract.py index f1b97ba..1d18c4d 100644 --- a/pymic/net_run_ssl/ssl_abstract.py +++ b/pymic/net_run_ssl/ssl_abstract.py @@ -12,7 +12,6 @@ from pymic.loss.seg.ssl import EntropyLoss from pymic.net_run.agent_seg import SegmentationAgent from pymic.transform.trans_dict import TransformDict -from pymic.util.ramps import sigmoid_rampup class SSLSegAgent(SegmentationAgent): """ @@ -70,9 +69,6 @@ def worker_init_fn(worker_id): batch_size = bn_train_unlab, shuffle=True, num_workers= num_worker, worker_init_fn=worker_init) - def training(self): - pass - def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): loss_scalar ={'train':train_scalars['loss'], 'valid':valid_scalars['loss']} diff --git a/pymic/net_run_ssl/ssl_cct.py b/pymic/net_run_ssl/ssl_cct.py index 80e9d07..d0c4f24 100644 --- a/pymic/net_run_ssl/ssl_cct.py +++ b/pymic/net_run_ssl/ssl_cct.py @@ -5,12 +5,12 @@ import torch.nn as nn import torch.nn.functional as F import numpy as np +from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.net_run_ssl.ssl_abstract import SSLSegAgent -from pymic.util.ramps import sigmoid_rampup -from pymic.util.general import keyword_match +from pymic.util.ramps import get_rampup_ratio def softmax_mse_loss(inputs, targets, conf_mask=False, threshold=None, use_softmax=False): assert inputs.requires_grad == True and targets.requires_grad == False @@ -73,6 +73,9 @@ def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] ssl_cfg 
= self.config['semi_supervised_learning']
+        iter_max     = self.config['training']['iter_max']
+        rampup_start = ssl_cfg.get('rampup_start', 0)
+        rampup_end   = ssl_cfg.get('rampup_end', iter_max)
         unsup_loss_name = ssl_cfg.get('unsupervised_loss', "MSE")
         self.unsup_loss_f = unsup_loss_dict[unsup_loss_name]
         train_loss = 0
@@ -118,20 +121,15 @@ def training(self):
             for p1_auxi in p1_aux:
                 loss_reg += self.unsup_loss_f(p1_auxi, p1, use_softmax = True)
             loss_reg = loss_reg / len(p1_aux)
-
-            iter_max = self.config['training']['iter_max']
-            ramp_up_length = ssl_cfg.get('ramp_up_length', iter_max)
-            regular_w = 0.0
-            if(self.glob_it > ssl_cfg.get('iter_sup', 0)):
-                regular_w = ssl_cfg.get('regularize_w', 0.1)
-            if(ramp_up_length is not None and self.glob_it < ramp_up_length):
-                regular_w = regular_w * sigmoid_rampup(self.glob_it, ramp_up_length)
-
+
+            rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid")
+            regular_w = ssl_cfg.get('regularize_w', 0.1) * rampup_ratio
             loss = loss_sup + regular_w*loss_reg
 
             loss.backward()
             self.optimizer.step()
-            if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")):
+            if(self.scheduler is not None and \
+                not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
                 self.scheduler.step()
             train_loss = train_loss + loss.item()
             train_loss_sup = train_loss_sup + loss_sup.item()
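In SSLCCT's loop above, the unsupervised term is the mean discrepancy between each auxiliary decoder and the main decoder, scaled by the ramped weight. A miniature of that reduction with stand-in tensors; F.mse_loss stands in for the module's softmax_mse_loss with use_softmax=True:

    import torch
    import torch.nn.functional as F

    p1     = torch.softmax(torch.randn(2, 3, 8, 8), dim=1)   # main decoder output
    p1_aux = [torch.randn(2, 3, 8, 8) for _ in range(3)]     # auxiliary decoder logits

    loss_reg = 0.0
    for p1_auxi in p1_aux:
        # the main prediction acts as the target, so it is not backpropagated
        loss_reg += F.mse_loss(torch.softmax(p1_auxi, dim=1), p1.detach())
    loss_reg = loss_reg / len(p1_aux)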
diff --git a/pymic/net_run_ssl/ssl_cps.py b/pymic/net_run_ssl/ssl_cps.py
index e39bec9..2264d0d 100644
--- a/pymic/net_run_ssl/ssl_cps.py
+++ b/pymic/net_run_ssl/ssl_cps.py
@@ -3,15 +3,30 @@
 import logging
 import numpy as np
 import torch
-import torch.optim as optim
+import torch.nn as nn
+from torch.optim import lr_scheduler
 from pymic.loss.seg.util import get_soft_label
 from pymic.loss.seg.util import reshape_prediction_and_ground_truth
 from pymic.loss.seg.util import get_classwise_dice
-from pymic.net_run.get_optimizer import get_optimizer, get_lr_scheduler
 from pymic.net_run_ssl.ssl_abstract import SSLSegAgent
 from pymic.net.net_dict_seg import SegNetDict
-from pymic.util.ramps import sigmoid_rampup
-from pymic.util.general import keyword_match
+from pymic.util.ramps import get_rampup_ratio
+
+class BiNet(nn.Module):
+    def __init__(self, params):
+        super(BiNet, self).__init__()
+        net_name  = params['net_type']
+        self.net1 = SegNetDict[net_name](params)
+        self.net2 = SegNetDict[net_name](params)
+
+    def forward(self, x):
+        out1 = self.net1(x)
+        out2 = self.net2(x)
+
+        if(self.training):
+            return out1, out2
+        else:
+            # average the two networks' predictions at inference time
+            return (out1 + out2) / 2
 
 class SSLCPS(SSLSegAgent):
     """
@@ -23,48 +38,27 @@ class SSLCPS(SSLSegAgent):
     """
     def __init__(self, config, stage = 'train'):
         super(SSLCPS, self).__init__(config, stage)
-        self.net2 = None
-        self.optimizer2 = None
-        self.scheduler2 = None
 
     def create_network(self):
-        super(SSLCPS, self).create_network()
-        if(self.net2 is None):
-            net_name = self.config['network']['net_type']
-            if(net_name not in SegNetDict):
-                raise ValueError("Undefined network {0:}".format(net_name))
-            self.net2 = SegNetDict[net_name](self.config['network'])
+        if(self.net is None):
+            self.net = BiNet(self.config['network'])
         if(self.tensor_type == 'float'):
-            self.net2.float()
+            self.net.float()
         else:
-            self.net2.double()
-
-    def train_valid(self):
-        # create optimizer for the second network
-        opt_params = self.config['training']
-        if(self.optimizer2 is None):
-            self.optimizer2 = get_optimizer(opt_params['optimizer'],
-                self.net2.parameters(), opt_params)
-            last_iter = -1
-            # if(self.checkpoint is not None):
-            #     self.optimizer2.load_state_dict(self.checkpoint['optimizer_state_dict'])
-            #     last_iter = self.checkpoint['iteration'] - 1
-        if(self.scheduler2 is None):
-            opt_params["laster_iter"] = last_iter
-            self.scheduler2 = get_lr_scheduler(self.optimizer2, opt_params)
-        super(SSLCPS, self).train_valid()
+            self.net.double()
 
     def training(self):
         class_num = self.config['network']['class_num']
         iter_valid = self.config['training']['iter_valid']
         ssl_cfg = self.config['semi_supervised_learning']
+        iter_max     = self.config['training']['iter_max']
+        rampup_start = ssl_cfg.get('rampup_start', 0)
+        rampup_end   = ssl_cfg.get('rampup_end', iter_max)
         train_loss = 0
         train_loss_sup1, train_loss_pseudo_sup1 = 0, 0
         train_loss_sup2, train_loss_pseudo_sup2 = 0, 0
         train_dice_list = []
         self.net.train()
-        self.net2.train()
-        self.net2.to(self.device)
         for it in range(iter_valid):
             try:
                 data_lab = next(self.trainIter)
@@ -86,9 +80,8 @@ def training(self):
 
             # zero the parameter gradients
             self.optimizer.zero_grad()
-            self.optimizer2.zero_grad()
 
-            outputs1, outputs2 = self.net(inputs), self.net2(inputs)
+            outputs1, outputs2 = self.net(inputs)
             outputs_soft1 = torch.softmax(outputs1, dim=1)
             outputs_soft2 = torch.softmax(outputs2, dim=1)
@@ -106,13 +99,8 @@ def training(self):
             pse_sup1 = self.get_loss_value(data_unlab, outputs1[n0:], pse_prob2)
             pse_sup2 = self.get_loss_value(data_unlab, outputs2[n0:], pse_prob1)
 
-            iter_max = self.config['training']['iter_max']
-            ramp_up_len = ssl_cfg.get('ramp_up_length', iter_max)
-            regular_w = 0.0
-            if(self.glob_it > ssl_cfg.get('iter_sup', 0)):
-                regular_w = ssl_cfg.get('regularize_w', 0.1)
-                if(ramp_up_len is not None and self.glob_it < ramp_up_len):
-                    regular_w = regular_w * sigmoid_rampup(self.glob_it, ramp_up_len)
+            rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid")
+            regular_w = ssl_cfg.get('regularize_w', 0.1) * rampup_ratio
 
             model1_loss = loss_sup1 + regular_w * pse_sup1
             model2_loss = loss_sup2 + regular_w * pse_sup2
@@ -120,10 +108,9 @@ def training(self):
 
             loss.backward()
             self.optimizer.step()
-            self.optimizer2.step()
-            if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")):
+            if(self.scheduler is not None and \
+                not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
                 self.scheduler.step()
-            self.scheduler2.step()
 
             train_loss = train_loss + loss.item()
             train_loss_sup1 = train_loss_sup1 + loss_sup1.item()
@@ -152,13 +139,7 @@ def training(self):
             'loss_pse_sup1':train_avg_loss_pse_sup1, 'loss_pse_sup2': train_avg_loss_pse_sup2,
             'regular_w':regular_w, 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice}
         return train_scalers
-
-    def validation(self):
-        return_value = super(SSLCPS, self).validation()
-        if(keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")):
-            self.scheduler2.step(return_value['avg_dice'])
-        return return_value
-
+
     def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it):
         loss_scalar ={'train':train_scalars['loss'],
                       'valid':valid_scalars['loss']}
diff --git a/pymic/net_run_ssl/ssl_em.py b/pymic/net_run_ssl/ssl_em.py
index e4b1f19..810a90c 100644
--- a/pymic/net_run_ssl/ssl_em.py
+++ b/pymic/net_run_ssl/ssl_em.py
@@ -3,14 +3,14 @@
 import logging
 import numpy as np
 import torch
+from torch.optim import lr_scheduler
 from pymic.loss.seg.util import get_soft_label
 from pymic.loss.seg.util import reshape_prediction_and_ground_truth
 from pymic.loss.seg.util import get_classwise_dice
 from pymic.loss.seg.ssl import EntropyLoss
 from pymic.net_run_ssl.ssl_abstract import 
SSLSegAgent from pymic.transform.trans_dict import TransformDict -from pymic.util.ramps import sigmoid_rampup -from pymic.util.general import keyword_match +from pymic.util.ramps import get_rampup_ratio class SSLEntropyMinimization(SSLSegAgent): """ @@ -29,6 +29,9 @@ def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] ssl_cfg = self.config['semi_supervised_learning'] + iter_max = self.config['training']['iter_max'] + rampup_start = ssl_cfg.get('rampup_start', 0) + rampup_end = ssl_cfg.get('rampup_end', iter_max) train_loss = 0 train_loss_sup = 0 train_loss_reg = 0 @@ -64,18 +67,15 @@ def training(self): loss_dict = {"prediction":outputs, 'softmax':True} loss_reg = EntropyLoss()(loss_dict) - iter_max = self.config['training']['iter_max'] - ramp_up_length = ssl_cfg.get('ramp_up_length', iter_max) - regular_w = 0.0 - if(self.glob_it > ssl_cfg.get('iter_sup', 0)): - regular_w = ssl_cfg.get('regularize_w', 0.1) - if(ramp_up_length is not None and self.glob_it < ramp_up_length): - regular_w = regular_w * sigmoid_rampup(self.glob_it, ramp_up_length) + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") + regular_w = ssl_cfg.get('regularize_w', 0.1) * rampup_ratio + loss = loss_sup + regular_w*loss_reg # if (self.config['training']['use']) loss.backward() self.optimizer.step() - if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")): + if(self.scheduler is not None and \ + not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): self.scheduler.step() train_loss = train_loss + loss.item() diff --git a/pymic/net_run_ssl/ssl_mt.py b/pymic/net_run_ssl/ssl_mt.py index aa7fbff..0456726 100644 --- a/pymic/net_run_ssl/ssl_mt.py +++ b/pymic/net_run_ssl/ssl_mt.py @@ -3,13 +3,13 @@ import logging import torch import numpy as np +from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.net_run_ssl.ssl_abstract import SSLSegAgent from pymic.net.net_dict_seg import SegNetDict -from pymic.util.ramps import sigmoid_rampup -from pymic.util.general import keyword_match +from pymic.util.ramps import get_rampup_ratio class SSLMeanTeacher(SSLSegAgent): """ @@ -39,6 +39,9 @@ def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] ssl_cfg = self.config['semi_supervised_learning'] + iter_max = self.config['training']['iter_max'] + rampup_start = ssl_cfg.get('rampup_start', 0) + rampup_end = ssl_cfg.get('rampup_end', iter_max) train_loss = 0 train_loss_sup = 0 train_loss_reg = 0 @@ -82,19 +85,16 @@ def training(self): outputs_ema = self.net_ema(inputs_ema) p1_ema_soft = torch.softmax(outputs_ema, dim=1) - iter_max = self.config['training']['iter_max'] - ramp_up_length = ssl_cfg.get('ramp_up_length', iter_max) - regular_w = 0.0 - if(self.glob_it > ssl_cfg.get('iter_sup', 0)): - regular_w = ssl_cfg.get('regularize_w', 0.1) - if(ramp_up_length is not None and self.glob_it < ramp_up_length): - regular_w = regular_w * sigmoid_rampup(self.glob_it, ramp_up_length) + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") + regular_w = ssl_cfg.get('regularize_w', 0.1) * rampup_ratio + loss_reg = torch.nn.MSELoss()(p1_soft, p1_ema_soft) loss = loss_sup + regular_w*loss_reg loss.backward() self.optimizer.step() - if(not 
keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")):
+            if(self.scheduler is not None and \
+                not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
                 self.scheduler.step()
 
             # update EMA
diff --git a/pymic/net_run_ssl/ssl_uamt.py b/pymic/net_run_ssl/ssl_uamt.py
index 3352231..360dab1 100644
--- a/pymic/net_run_ssl/ssl_uamt.py
+++ b/pymic/net_run_ssl/ssl_uamt.py
@@ -3,12 +3,12 @@
 import logging
 import torch
 import numpy as np
+from torch.optim import lr_scheduler
 from pymic.loss.seg.util import get_soft_label
 from pymic.loss.seg.util import reshape_prediction_and_ground_truth
 from pymic.loss.seg.util import get_classwise_dice
 from pymic.net_run_ssl.ssl_mt import SSLMeanTeacher
-from pymic.util.ramps import sigmoid_rampup
-from pymic.util.general import keyword_match
+from pymic.util.ramps import get_rampup_ratio
 
 class SSLUncertaintyAwareMeanTeacher(SSLMeanTeacher):
     """
@@ -22,6 +22,9 @@ def training(self):
         class_num = self.config['network']['class_num']
         iter_valid = self.config['training']['iter_valid']
         ssl_cfg = self.config['semi_supervised_learning']
+        iter_max     = self.config['training']['iter_max']
+        rampup_start = ssl_cfg.get('rampup_start', 0)
+        rampup_end   = ssl_cfg.get('rampup_end', iter_max)
         train_loss = 0
         train_loss_sup = 0
         train_loss_reg = 0
@@ -81,24 +84,19 @@ def training(self):
             uncertainty = -1.0 * torch.sum(preds*torch.log(preds + 1e-6), dim=1, keepdim=True)
 
-            iter_max = self.config['training']['iter_max']
-            ramp_up_length = ssl_cfg.get('ramp_up_length', iter_max)
-            threshold_ramp = sigmoid_rampup(self.glob_it, iter_max)
+            rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid")
             class_num = list(y0.shape)[1]
-            threshold = (0.75+0.25*threshold_ramp)*np.log(class_num)
+            threshold = (0.75+0.25*rampup_ratio)*np.log(class_num)
             mask = (uncertainty < threshold).float()
             loss_reg = torch.sum(mask*square_error)/(2*torch.sum(mask)+1e-16)
 
-            regular_w = 0.0
-            if(self.glob_it > ssl_cfg.get('iter_sup', 0)):
-                regular_w = ssl_cfg.get('regularize_w', 0.1)
-            if(ramp_up_length is not None and self.glob_it < ramp_up_length):
-                regular_w = regular_w * sigmoid_rampup(self.glob_it, ramp_up_length)
+            regular_w = ssl_cfg.get('regularize_w', 0.1) * rampup_ratio
 
             loss = loss_sup + regular_w*loss_reg
             loss.backward()
             self.optimizer.step()
-            if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")):
+            if(self.scheduler is not None and \
+                not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
                 self.scheduler.step()
 
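SSLUncertaintyAwareMeanTeacher keeps the consistency loss only where the teacher's predictive entropy is below (0.75 + 0.25*rampup_ratio)*ln(class_num); for two classes the threshold therefore ramps from about 0.52 up to ln 2 (about 0.693) nats, so more pixels pass as training progresses. A miniature of the gating with stand-in tensors:

    import numpy as np
    import torch

    class_num = 2
    preds = torch.softmax(torch.randn(4, class_num, 8, 8), dim=1)  # teacher probabilities
    uncertainty = -1.0 * torch.sum(preds * torch.log(preds + 1e-6),
                                   dim=1, keepdim=True)

    rampup_ratio = 0.0                        # value at the start of training
    threshold = (0.75 + 0.25 * rampup_ratio) * np.log(class_num)
    mask = (uncertainty < threshold).float()  # 1 where the teacher is confident

    square_error = torch.randn(4, 1, 8, 8) ** 2   # stand-in for (student - teacher)^2
    loss_reg = torch.sum(mask * square_error) / (2 * torch.sum(mask) + 1e-16)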
diff --git a/pymic/net_run_ssl/ssl_urpc.py b/pymic/net_run_ssl/ssl_urpc.py
index 5d269fd..d0179cd 100644
--- a/pymic/net_run_ssl/ssl_urpc.py
+++ b/pymic/net_run_ssl/ssl_urpc.py
@@ -4,12 +4,12 @@
 import torch
 import torch.nn as nn
 import numpy as np
+from torch.optim import lr_scheduler
 from pymic.loss.seg.util import get_soft_label
 from pymic.loss.seg.util import reshape_prediction_and_ground_truth
 from pymic.loss.seg.util import get_classwise_dice
 from pymic.net_run_ssl.ssl_abstract import SSLSegAgent
-from pymic.util.ramps import sigmoid_rampup
-from pymic.util.general import keyword_match
+from pymic.util.ramps import get_rampup_ratio
 
 class SSLURPC(SSLSegAgent):
     """
@@ -24,6 +24,9 @@ def training(self):
         class_num = self.config['network']['class_num']
         iter_valid = self.config['training']['iter_valid']
         ssl_cfg = self.config['semi_supervised_learning']
+        iter_max     = self.config['training']['iter_max']
+        rampup_start = ssl_cfg.get('rampup_start', 0)
+        rampup_end   = ssl_cfg.get('rampup_end', iter_max)
         train_loss = 0
         train_loss_sup = 0
         train_loss_reg = 0
@@ -78,19 +81,14 @@ def training(self):
                 loss_reg += loss_i
             loss_reg = loss_reg / len(outputs_list)
 
-            iter_max = self.config['training']['iter_max']
-            ramp_up_length = ssl_cfg.get('ramp_up_length', iter_max)
-            regular_w = 0.0
-            if(self.glob_it > ssl_cfg.get('iter_sup', 0)):
-                regular_w = ssl_cfg.get('regularize_w', 0.1)
-            if(ramp_up_length is not None and self.glob_it < ramp_up_length):
-                regular_w = regular_w * sigmoid_rampup(self.glob_it, ramp_up_length)
-
+            rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid")
+            regular_w = ssl_cfg.get('regularize_w', 0.1) * rampup_ratio
             loss = loss_sup + regular_w*loss_reg
 
             loss.backward()
             self.optimizer.step()
-            if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")):
+            if(self.scheduler is not None and \
+                not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
                 self.scheduler.step()
             train_loss = train_loss + loss.item()
             train_loss_sup = train_loss_sup + loss_sup.item()
diff --git a/pymic/net_run_wsl/wsl_dmpls.py b/pymic/net_run_wsl/wsl_dmpls.py
index 1e60e47..a198ddc 100644
--- a/pymic/net_run_wsl/wsl_dmpls.py
+++ b/pymic/net_run_wsl/wsl_dmpls.py
@@ -4,13 +4,13 @@
 import numpy as np
 import random
 import torch
+from torch.optim import lr_scheduler
 from pymic.loss.seg.util import get_soft_label
 from pymic.loss.seg.util import reshape_prediction_and_ground_truth
 from pymic.loss.seg.util import get_classwise_dice
 from pymic.loss.seg.dice import DiceLoss
 from pymic.net_run_wsl.wsl_abstract import WSLSegAgent
-from pymic.util.ramps import sigmoid_rampup
-from pymic.util.general import keyword_match
+from pymic.util.ramps import get_rampup_ratio
 
 class WSLDMPLS(WSLSegAgent):
     """
@@ -32,6 +32,9 @@ def training(self):
         class_num = self.config['network']['class_num']
         iter_valid = self.config['training']['iter_valid']
         wsl_cfg = self.config['weakly_supervised_learning']
+        iter_max     = self.config['training']['iter_max']
+        rampup_start = wsl_cfg.get('rampup_start', 0)
+        rampup_end   = wsl_cfg.get('rampup_end', iter_max)
         train_loss = 0
         train_loss_sup = 0
         train_loss_reg = 0
@@ -73,18 +76,14 @@ def training(self):
             loss_dict2 = {"prediction":outputs2, 'ground_truth':pseudo_lab}
             loss_reg   = 0.5 * (loss_calculator(loss_dict1) + loss_calculator(loss_dict2))
 
-            iter_max = self.config['training']['iter_max']
-            ramp_up_length = wsl_cfg.get('ramp_up_length', iter_max)
-            regular_w = 0.0
-            if(self.glob_it > wsl_cfg.get('iter_sup', 0)):
-                regular_w = wsl_cfg.get('regularize_w', 0.1)
-            if(ramp_up_length is not None and self.glob_it < ramp_up_length):
-                regular_w = regular_w * sigmoid_rampup(self.glob_it, ramp_up_length)
+            rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid")
+            regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio
 
             loss = loss_sup + regular_w*loss_reg
 
             loss.backward()
             self.optimizer.step()
-            if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")):
+            if(self.scheduler is not None and \
+                not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
                 self.scheduler.step()
 
             train_loss = train_loss + loss.item()
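In WSLDMPLS above, pseudo_lab supervises both branches, and per the MICCAI 2022 paper cited in its docstring it is the argmax of a randomly mixed prediction from the two branches. The mixing itself is outside this hunk, so the following is a hedged sketch of how such a dynamically mixed pseudo label can be formed, not the file's exact code:

    import random
    import torch

    outputs_soft1 = torch.softmax(torch.randn(2, 3, 8, 8), dim=1)  # branch 1
    outputs_soft2 = torch.softmax(torch.randn(2, 3, 8, 8), dim=1)  # branch 2

    alpha = random.random()    # redrawn every iteration, hence "dynamically mixed"
    mixed = alpha * outputs_soft1 + (1.0 - alpha) * outputs_soft2
    pseudo_lab = torch.argmax(mixed.detach(), dim=1, keepdim=True)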
diff --git a/pymic/net_run_wsl/wsl_em.py b/pymic/net_run_wsl/wsl_em.py
index 66823c1..3b2d595 100644
--- a/pymic/net_run_wsl/wsl_em.py
+++ b/pymic/net_run_wsl/wsl_em.py
@@ -3,18 +3,18 @@
 import logging
 import numpy as np
 import torch
+from torch.optim import lr_scheduler
 from pymic.loss.seg.util import get_soft_label
 from pymic.loss.seg.util import reshape_prediction_and_ground_truth
 from pymic.loss.seg.util import get_classwise_dice
 from pymic.loss.seg.ssl import EntropyLoss
 from pymic.net_run.agent_seg import SegmentationAgent
 from pymic.net_run_wsl.wsl_abstract import WSLSegAgent
-from pymic.util.ramps import sigmoid_rampup
-from pymic.util.general import keyword_match
+from pymic.util.ramps import get_rampup_ratio
 
 class WSLEntropyMinimization(WSLSegAgent):
     """
-    Training and testing agent for semi-supervised segmentation
+    Weakly supervised segmentation with Entropy Minimization Regularization.
     """
     def __init__(self, config, stage = 'train'):
         super(WSLEntropyMinimization, self).__init__(config, stage)
@@ -23,6 +23,9 @@ def training(self):
         class_num = self.config['network']['class_num']
         iter_valid = self.config['training']['iter_valid']
         wsl_cfg = self.config['weakly_supervised_learning']
+        iter_max     = self.config['training']['iter_max']
+        rampup_start = wsl_cfg.get('rampup_start', 0)
+        rampup_end   = wsl_cfg.get('rampup_end', iter_max)
         train_loss = 0
         train_loss_sup = 0
         train_loss_reg = 0
@@ -50,18 +53,14 @@ def training(self):
             loss_dict= {"prediction":outputs, 'softmax':True}
             loss_reg = EntropyLoss()(loss_dict)
 
-            iter_max = self.config['training']['iter_max']
-            ramp_up_length = wsl_cfg.get('ramp_up_length', iter_max)
-            regular_w = 0.0
-            if(self.glob_it > wsl_cfg.get('iter_sup', 0)):
-                regular_w = wsl_cfg.get('regularize_w', 0.1)
-            if(ramp_up_length is not None and self.glob_it < ramp_up_length):
-                regular_w = regular_w * sigmoid_rampup(self.glob_it, ramp_up_length)
+            rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid")
+            regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio
 
             loss = loss_sup + regular_w*loss_reg
             loss.backward()
             self.optimizer.step()
-            if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")):
+            if(self.scheduler is not None and \
+                not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
                 self.scheduler.step()
 
             train_loss = train_loss + loss.item()
diff --git a/pymic/net_run_wsl/wsl_gatedcrf.py b/pymic/net_run_wsl/wsl_gatedcrf.py
index 2ae8318..2be8856 100644
--- a/pymic/net_run_wsl/wsl_gatedcrf.py
+++ b/pymic/net_run_wsl/wsl_gatedcrf.py
@@ -3,13 +3,13 @@
 import logging
 import numpy as np
 import torch
+from torch.optim import lr_scheduler
 from pymic.loss.seg.util import get_soft_label
 from pymic.loss.seg.util import reshape_prediction_and_ground_truth
 from pymic.loss.seg.util import get_classwise_dice
 from pymic.loss.seg.gatedcrf import ModelLossSemsegGatedCRF
 from pymic.net_run_wsl.wsl_abstract import WSLSegAgent
-from pymic.util.ramps import sigmoid_rampup
-from pymic.util.general import keyword_match
+from pymic.util.ramps import get_rampup_ratio
 
 class WSLGatedCRF(WSLSegAgent):
     """
@@ -38,6 +38,9 @@ def training(self):
         class_num = self.config['network']['class_num']
         iter_valid = self.config['training']['iter_valid']
         wsl_cfg = self.config['weakly_supervised_learning']
+        iter_max     = self.config['training']['iter_max']
+        rampup_start = wsl_cfg.get('rampup_start', 0)
+        rampup_end   = wsl_cfg.get('rampup_end', iter_max)
         train_loss = 0
         train_loss_sup = 0
         train_loss_reg = 0
@@ -81,18 +84,14 @@ def training(self):
             loss_reg = gatecrf_loss(outputs_soft, self.kernels, self.radius,
                 batch_dict, input_shape[-2], input_shape[-1])["loss"]
 
-            iter_max = self.config['training']['iter_max']
-            ramp_up_length = wsl_cfg.get('ramp_up_length', iter_max)
-            regular_w = 0.0
-            if(self.glob_it > wsl_cfg.get('iter_sup', 0)):
-                regular_w = wsl_cfg.get('regularize_w', 0.1)
-            if(ramp_up_length is not None and self.glob_it < 
ramp_up_length): - regular_w = regular_w * sigmoid_rampup(self.glob_it, ramp_up_length) + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") + regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg loss.backward() self.optimizer.step() - if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")): + if(self.scheduler is not None and \ + not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): self.scheduler.step() train_loss = train_loss + loss.item() diff --git a/pymic/net_run_wsl/wsl_mumford_shah.py b/pymic/net_run_wsl/wsl_mumford_shah.py index 095a0f6..df4c68f 100644 --- a/pymic/net_run_wsl/wsl_mumford_shah.py +++ b/pymic/net_run_wsl/wsl_mumford_shah.py @@ -3,13 +3,13 @@ import logging import numpy as np import torch +from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.mumford_shah import MumfordShahLoss from pymic.net_run_wsl.wsl_abstract import WSLSegAgent -from pymic.util.ramps import sigmoid_rampup -from pymic.util.general import keyword_match +from pymic.util.ramps import get_rampup_ratio class WSLMumfordShah(WSLSegAgent): """ @@ -25,6 +25,9 @@ def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] wsl_cfg = self.config['weakly_supervised_learning'] + iter_max = self.config['training']['iter_max'] + rampup_start = wsl_cfg.get('rampup_start', 0) + rampup_end = wsl_cfg.get('rampup_end', iter_max) train_loss = 0 train_loss_sup = 0 train_loss_reg = 0 @@ -54,18 +57,14 @@ def training(self): loss_dict = {"prediction":outputs, 'image':inputs} loss_reg = reg_loss_calculator(loss_dict) - iter_max = self.config['training']['iter_max'] - ramp_up_length = wsl_cfg.get('ramp_up_length', iter_max) - regular_w = 0.0 - if(self.glob_it > wsl_cfg.get('iter_sup', 0)): - regular_w = wsl_cfg.get('regularize_w', 0.1) - if(ramp_up_length is not None and self.glob_it < ramp_up_length): - regular_w = regular_w * sigmoid_rampup(self.glob_it, ramp_up_length) + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") + regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg # if (self.config['training']['use']) loss.backward() self.optimizer.step() - if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")): + if(self.scheduler is not None and \ + not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): self.scheduler.step() train_loss = train_loss + loss.item() diff --git a/pymic/net_run_wsl/wsl_tv.py b/pymic/net_run_wsl/wsl_tv.py index fde1c10..2e56cb4 100644 --- a/pymic/net_run_wsl/wsl_tv.py +++ b/pymic/net_run_wsl/wsl_tv.py @@ -3,12 +3,13 @@ import logging import numpy as np import torch +from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.ssl import TotalVariationLoss from pymic.net_run_wsl.wsl_abstract import WSLSegAgent -from pymic.util.ramps import sigmoid_rampup +from pymic.util.ramps import get_rampup_ratio from pymic.util.general import keyword_match class WSLTotalVariation(WSLSegAgent): @@ -22,6 +23,9 @@ def training(self): class_num = self.config['network']['class_num'] iter_valid = 
self.config['training']['iter_valid'] wsl_cfg = self.config['weakly_supervised_learning'] + iter_max = self.config['training']['iter_max'] + rampup_start = wsl_cfg.get('rampup_start', 0) + rampup_end = wsl_cfg.get('rampup_end', iter_max) train_loss = 0 train_loss_sup = 0 train_loss_reg = 0 @@ -49,18 +53,14 @@ def training(self): loss_dict = {"prediction":outputs, 'softmax':True} loss_reg = TotalVariationLoss()(loss_dict) - iter_max = self.config['training']['iter_max'] - ramp_up_length = wsl_cfg.get('ramp_up_length', iter_max) - regular_w = 0.0 - if(self.glob_it > wsl_cfg.get('iter_sup', 0)): - regular_w = wsl_cfg.get('regularize_w', 0.1) - if(ramp_up_length is not None and self.glob_it < ramp_up_length): - regular_w = regular_w * sigmoid_rampup(self.glob_it, ramp_up_length) + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") + regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg # if (self.config['training']['use']) loss.backward() self.optimizer.step() - if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")): + if(self.scheduler is not None and \ + not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): self.scheduler.step() train_loss = train_loss + loss.item() diff --git a/pymic/net_run_wsl/wsl_ustm.py b/pymic/net_run_wsl/wsl_ustm.py index 6083069..0a2f7e1 100644 --- a/pymic/net_run_wsl/wsl_ustm.py +++ b/pymic/net_run_wsl/wsl_ustm.py @@ -5,12 +5,13 @@ import random import torch import torch.nn.functional as F +from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.net.net_dict_seg import SegNetDict from pymic.net_run_wsl.wsl_abstract import WSLSegAgent -from pymic.util.ramps import sigmoid_rampup +from pymic.util.ramps import get_rampup_ratio from pymic.util.general import keyword_match class WSLUSTM(WSLSegAgent): @@ -42,6 +43,9 @@ def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] wsl_cfg = self.config['weakly_supervised_learning'] + iter_max = self.config['training']['iter_max'] + rampup_start = wsl_cfg.get('rampup_start', 0) + rampup_end = wsl_cfg.get('rampup_end', iter_max) train_loss = 0 train_loss_sup = 0 train_loss_reg = 0 @@ -97,24 +101,19 @@ def training(self): uncertainty = -1.0 * torch.sum(preds*torch.log(preds + 1e-6), dim=1, keepdim=True) - iter_max = self.config['training']['iter_max'] - ramp_up_length = wsl_cfg.get('ramp_up_length', iter_max) - threshold_ramp = sigmoid_rampup(self.glob_it, iter_max) + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") class_num = list(y.shape)[1] - threshold = (0.75+0.25*threshold_ramp)*np.log(class_num) + threshold = (0.75+0.25*rampup_ratio)*np.log(class_num) mask = (uncertainty < threshold).float() loss_reg = torch.sum(mask*square_error)/(2*torch.sum(mask)+1e-16) - regular_w = 0.0 - if(self.glob_it > wsl_cfg.get('iter_sup', 0)): - regular_w = wsl_cfg.get('regularize_w', 0.1) - if(ramp_up_length is not None and self.glob_it < ramp_up_length): - regular_w = regular_w * sigmoid_rampup(self.glob_it, ramp_up_length) + regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg loss.backward() self.optimizer.step() - if(not keyword_match(self.config['training']['lr_scheduler'], "ReduceLROnPlateau")): + if(self.scheduler is not None and \ 
+ not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): self.scheduler.step() # update EMA diff --git a/pymic/util/ramps.py b/pymic/util/ramps.py index e344cfe..b58adb6 100644 --- a/pymic/util/ramps.py +++ b/pymic/util/ramps.py @@ -10,24 +10,21 @@ 0 and 1. """ -def sigmoid_rampup(i, length): - """Exponential rampup from https://arxiv.org/abs/1610.02242""" - if length == 0: - return 1.0 - else: - i = np.clip(i, 0.0, length) - phase = 1.0 - (i + 0.0) / length - return float(np.exp(-5.0 * phase * phase)) - -def linear_rampup(i, length): - """Linear rampup""" - assert i >= 0 and length >= 0 - i = np.clip(i, 0.0, length) - return (i + 0.0) / length +def get_rampup_ratio(i, start, end, mode = "linear"): + if( i < start): + rampup = 0.0 + elif(i > end): + rampup = 1.0 + elif(mode == "linear"): + rampup = (i - start) / (end - start) + elif(mode == "sigmoid"): + phase = 1.0 - (i - start) / (end - start) + rampup = float(np.exp(-5.0 * phase * phase)) + return rampup -def cosine_rampdown(i, length): +def cosine_rampdown(i, start, end): """Cosine rampdown from https://arxiv.org/abs/1608.03983""" i = np.clip(i, 0.0, length) return float(.5 * (np.cos(np.pi * i / length) + 1)) \ No newline at end of file From 04567f1cc4c1abf23dc97c83a8f5168c260ead2c Mon Sep 17 00:00:00 2001 From: taigw Date: Thu, 18 Aug 2022 16:00:45 +0800 Subject: [PATCH 20/26] update log name update log name --- pymic/net_run/net_run.py | 4 ++-- pymic/net_run_nll/nll_main.py | 2 +- pymic/net_run_ssl/ssl_main.py | 2 +- pymic/net_run_wsl/wsl_main.py | 2 +- pymic/util/average_model.py | 1 + 5 files changed, 6 insertions(+), 5 deletions(-) diff --git a/pymic/net_run/net_run.py b/pymic/net_run/net_run.py index 4ec1ce7..971af7e 100644 --- a/pymic/net_run/net_run.py +++ b/pymic/net_run/net_run.py @@ -18,8 +18,8 @@ def main(): config = synchronize_config(config) log_dir = config['training']['ckpt_save_dir'] if(not os.path.exists(log_dir)): - os.mkdir(log_dir) - logging.basicConfig(filename=log_dir+"/log.txt", level=logging.INFO, + os.makedirs(log_dir, exist_ok=True) + logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, format='%(message)s') logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) logging_config(config) diff --git a/pymic/net_run_nll/nll_main.py b/pymic/net_run_nll/nll_main.py index d1ae7a1..7d8f1f8 100644 --- a/pymic/net_run_nll/nll_main.py +++ b/pymic/net_run_nll/nll_main.py @@ -23,7 +23,7 @@ def main(): log_dir = config['training']['ckpt_save_dir'] if(not os.path.exists(log_dir)): os.mkdir(log_dir) - logging.basicConfig(filename=log_dir+"/log.txt", level=logging.INFO, + logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, format='%(message)s') logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) logging_config(config) diff --git a/pymic/net_run_ssl/ssl_main.py b/pymic/net_run_ssl/ssl_main.py index 6bddf29..d904ab1 100644 --- a/pymic/net_run_ssl/ssl_main.py +++ b/pymic/net_run_ssl/ssl_main.py @@ -32,7 +32,7 @@ def main(): log_dir = config['training']['ckpt_save_dir'] if(not os.path.exists(log_dir)): os.mkdir(log_dir) - logging.basicConfig(filename=log_dir+"/log.txt", level=logging.INFO, + logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, format='%(message)s') logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) logging_config(config) diff --git a/pymic/net_run_wsl/wsl_main.py b/pymic/net_run_wsl/wsl_main.py index 916e1d8..abedb6b 100644 --- a/pymic/net_run_wsl/wsl_main.py 
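The refactor above replaces the per-method ramp-up logic with a single `get_rampup_ratio` call. As a sanity check, a minimal sketch of how the returned ratio scales the regularization weight (`regularize_w = 0.1` matches the default shown above; the 10000-iteration span is a hypothetical config setting):

```python
from pymic.util.ramps import get_rampup_ratio

# hypothetical schedule: ramp the weight up over the first 10000 iterations
regularize_w, rampup_start, rampup_end = 0.1, 0, 10000
for it in [0, 2500, 5000, 7500, 10000]:
    w = regularize_w * get_rampup_ratio(it, rampup_start, rampup_end, "sigmoid")
    print(it, round(w, 4))  # grows monotonically from ~0.0007 to 0.1
```

Note that `get_rampup_ratio` only assigns `rampup` for the "linear" and "sigmoid" modes; calling it with any other mode name while `start <= i <= end` would raise an `UnboundLocalError`.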
From 04567f1cc4c1abf23dc97c83a8f5168c260ead2c Mon Sep 17 00:00:00 2001
From: taigw
Date: Thu, 18 Aug 2022 16:00:45 +0800
Subject: [PATCH 20/26] update log name

update log name
---
 pymic/net_run/net_run.py      | 4 ++--
 pymic/net_run_nll/nll_main.py | 2 +-
 pymic/net_run_ssl/ssl_main.py | 2 +-
 pymic/net_run_wsl/wsl_main.py | 2 +-
 pymic/util/average_model.py   | 1 +
 5 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/pymic/net_run/net_run.py b/pymic/net_run/net_run.py
index 4ec1ce7..971af7e 100644
--- a/pymic/net_run/net_run.py
+++ b/pymic/net_run/net_run.py
@@ -18,8 +18,8 @@ def main():
     config = synchronize_config(config)
     log_dir = config['training']['ckpt_save_dir']
     if(not os.path.exists(log_dir)):
-        os.mkdir(log_dir)
-    logging.basicConfig(filename=log_dir+"/log.txt", level=logging.INFO,
+        os.makedirs(log_dir, exist_ok=True)
+    logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO,
                         format='%(message)s')
     logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
     logging_config(config)
diff --git a/pymic/net_run_nll/nll_main.py b/pymic/net_run_nll/nll_main.py
index d1ae7a1..7d8f1f8 100644
--- a/pymic/net_run_nll/nll_main.py
+++ b/pymic/net_run_nll/nll_main.py
@@ -23,7 +23,7 @@ def main():
     log_dir = config['training']['ckpt_save_dir']
     if(not os.path.exists(log_dir)):
         os.mkdir(log_dir)
-    logging.basicConfig(filename=log_dir+"/log.txt", level=logging.INFO,
+    logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO,
                         format='%(message)s')
     logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
     logging_config(config)
diff --git a/pymic/net_run_ssl/ssl_main.py b/pymic/net_run_ssl/ssl_main.py
index 6bddf29..d904ab1 100644
--- a/pymic/net_run_ssl/ssl_main.py
+++ b/pymic/net_run_ssl/ssl_main.py
@@ -32,7 +32,7 @@ def main():
     log_dir = config['training']['ckpt_save_dir']
     if(not os.path.exists(log_dir)):
         os.mkdir(log_dir)
-    logging.basicConfig(filename=log_dir+"/log.txt", level=logging.INFO,
+    logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO,
                         format='%(message)s')
     logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
     logging_config(config)
diff --git a/pymic/net_run_wsl/wsl_main.py b/pymic/net_run_wsl/wsl_main.py
index 916e1d8..abedb6b 100644
--- a/pymic/net_run_wsl/wsl_main.py
+++ b/pymic/net_run_wsl/wsl_main.py
@@ -31,7 +31,7 @@ def main():
     log_dir = config['training']['ckpt_save_dir']
     if(not os.path.exists(log_dir)):
         os.mkdir(log_dir)
-    logging.basicConfig(filename=log_dir+"/log.txt", level=logging.INFO,
+    logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO,
                         format='%(message)s')
     logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
     logging_config(config)
diff --git a/pymic/util/average_model.py b/pymic/util/average_model.py
index 0b6fb29..73a537f 100644
--- a/pymic/util/average_model.py
+++ b/pymic/util/average_model.py
@@ -1,3 +1,4 @@
+
 import torch

 checkpoint_name1 = "/home/guotai/projects/PyMIC/examples/brats/model/casecade/wt/unet3d_4_8000.pt"
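The net effect of this patch is that training and testing runs write to separate, stage-tagged log files instead of overwriting a shared `log.txt`. A minimal sketch of the resulting behaviour (the directory name is hypothetical):

```python
import logging, os, sys

stage = "train"            # first command-line argument parsed by main()
log_dir = "model/my_unet"  # hypothetical ckpt_save_dir from a config file
os.makedirs(log_dir, exist_ok=True)

# mirrors the updated basicConfig call: messages go to model/my_unet/log_train.txt
logging.basicConfig(filename=log_dir + "/log_{0:}.txt".format(stage),
                    level=logging.INFO, format='%(message)s')
logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
logging.info("logged to file and echoed to stdout")
```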
From 137c7ebec02482761a70a02db345c9e3a9ec935e Mon Sep 17 00:00:00 2001
From: taigw
Date: Fri, 19 Aug 2022 11:57:36 +0800
Subject: [PATCH 21/26] fix typo

---
 pymic/net/net3d/unet3d.py       |   3 +-
 pymic/net_run/agent_seg.py      |   4 +-
 pymic/net_run/net_run.py        |   2 +-
 pymic/net_run_nll/nll_clslsr.py | 191 ++++++++++++++++++++++++++++++++
 pymic/net_run_ssl/ssl_em.py     |   2 +-
 pymic/net_run_ssl/ssl_urpc.py   |   2 +-
 6 files changed, 197 insertions(+), 7 deletions(-)
 create mode 100644 pymic/net_run_nll/nll_clslsr.py

diff --git a/pymic/net/net3d/unet3d.py b/pymic/net/net3d/unet3d.py
index 058cb79..fdedf4d 100644
--- a/pymic/net/net3d/unet3d.py
+++ b/pymic/net/net3d/unet3d.py
@@ -96,7 +96,6 @@ def __init__(self, params):
         self.n_class = self.params['class_num']
         self.trilinear = self.params['trilinear']
         self.deep_sup = self.params['deep_supervise']
-        self.stage = self.params['stage']

         assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4)
         self.in_conv= ConvBlock(self.in_chns, self.ft_chns[0], self.dropout[0])
@@ -134,7 +133,7 @@ def forward(self, x):
         x_d1 = self.up3(x_d2, x1)
         x_d0 = self.up4(x_d1, x0)
         output = self.out_conv(x_d0)
-        if(self.deep_sup and self.stage == "train"):
+        if(self.deep_sup):
             out_shape = list(output.shape)[2:]
             output1 = self.out_conv1(x_d1)
             output1 = interpolate(output1, out_shape, mode = 'trilinear')
diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py
index d8e74a2..5a53f3e 100644
--- a/pymic/net_run/agent_seg.py
+++ b/pymic/net_run/agent_seg.py
@@ -307,7 +307,7 @@ def train_valid(self):
         elif(isinstance(iter_save, (tuple, list))):
             iter_save_list = iter_save
         else:
-            iter_save_list = range(iter_start, iter_max + 1, iter_save)
+            iter_save_list = range(0, iter_max + 1, iter_save)

         self.max_val_dice = 0.0
         self.max_val_it = 0
@@ -519,7 +519,7 @@ def save_ouputs(self, data):
         filename_replace_source = self.config['testing'].get('filename_replace_source', None)
         filename_replace_target = self.config['testing'].get('filename_replace_target', None)
         if(not os.path.exists(output_dir)):
-            os.mkdir(output_dir)
+            os.makedirs(output_dir, exist_ok=True)

         names, pred = data['names'], data['predict']
         if(isinstance(pred, (list, tuple))):
diff --git a/pymic/net_run/net_run.py b/pymic/net_run/net_run.py
index 971af7e..4c953ad 100644
--- a/pymic/net_run/net_run.py
+++ b/pymic/net_run/net_run.py
@@ -10,7 +10,7 @@ def main():
     if(len(sys.argv) < 3):
         print('Number of arguments should be 3. e.g.')
-        print('    pymic_net_run train config.cfg')
+        print('    pymic_run train config.cfg')
         exit()
     stage = str(sys.argv[1])
     cfg_file = str(sys.argv[2])
diff --git a/pymic/net_run_nll/nll_clslsr.py b/pymic/net_run_nll/nll_clslsr.py
new file mode 100644
index 0000000..2894db6
--- /dev/null
+++ b/pymic/net_run_nll/nll_clslsr.py
@@ -0,0 +1,191 @@
+# -*- coding: utf-8 -*-
+"""
+Calculating the confidence map of labels of training samples,
+which is used in the method of SLSR.
+    Minqing Zhang et al., Characterizing Label Errors: Confident Learning
+    for Noisy-Labeled Image Segmentation, MICCAI 2020.
+"""
+
+from __future__ import print_function, division
+import cleanlab
+import logging
+import os
+import scipy
+import sys
+import torch
+import numpy as np
+import pandas as pd
+import torch.nn as nn
+import torchvision.transforms as transforms
+from PIL import Image
+from pymic.io.nifty_dataset import NiftyDataset
+from pymic.transform.trans_dict import TransformDict
+from pymic.util.parse_config import *
+from pymic.net_run.agent_seg import SegmentationAgent
+from pymic.net_run.infer_func import Inferer
+
+def get_confident_map(gt, pred, CL_type = 'both'):
+    """
+    gt: ground truth label (one-hot) with shape of NXC
+    pred: logit prediction of network with shape of NXC
+    """
+    prob = scipy.special.softmax(pred, axis = 1)
+    if CL_type in ['both', 'Qij']:
+        noise = cleanlab.pruning.get_noise_indices(gt, prob, prune_method='both', n_jobs=1)
+    elif CL_type == 'Cij':
+        noise = cleanlab.pruning.get_noise_indices(gt, pred, prune_method='both', n_jobs=1)
+    elif CL_type == 'intersection':
+        noise_qij = cleanlab.pruning.get_noise_indices(gt, prob, prune_method='both', n_jobs=1)
+        noise_cij = cleanlab.pruning.get_noise_indices(gt, pred, prune_method='both', n_jobs=1)
+        noise = noise_qij & noise_cij
+    elif CL_type == 'union':
+        noise_qij = cleanlab.pruning.get_noise_indices(gt, prob, prune_method='both', n_jobs=1)
+        noise_cij = cleanlab.pruning.get_noise_indices(gt, pred, prune_method='both', n_jobs=1)
+        noise = noise_qij | noise_cij
+    elif CL_type in ['prune_by_class', 'prune_by_noise_rate']:
+        noise = cleanlab.pruning.get_noise_indices(gt, prob, prune_method=CL_type, n_jobs=1)
+    return noise
+
+class NLLCLSLSR(SegmentationAgent):
+    def __init__(self, config, stage = 'test'):
+        super(NLLCLSLSR, self).__init__(config, stage)
+
+    def infer_with_cl(self):
+        device_ids = self.config['testing']['gpus']
+        device = torch.device("cuda:{0:}".format(device_ids[0]))
+        self.net.to(device)
+
+        if(self.config['testing'].get('evaluation_mode', True)):
+            self.net.eval()
+            if(self.config['testing'].get('test_time_dropout', False)):
+                def test_time_dropout(m):
+                    if(type(m) == nn.Dropout):
+                        logging.info('dropout layer')
+                        m.train()
+                self.net.apply(test_time_dropout)
+
+        ckpt_mode = self.config['testing']['ckpt_mode']
+        ckpt_name = self.get_checkpoint_name()
+        if(ckpt_mode == 3):
+            assert(isinstance(ckpt_name, (tuple, list)))
+            self.infer_with_multiple_checkpoints()
+            return
+        else:
+            if(isinstance(ckpt_name, (tuple, list))):
+                raise ValueError("ckpt_mode should be 3 if ckpt_name is a list")
+
+        # load network parameters and set the network as evaluation mode
+        checkpoint = torch.load(ckpt_name, map_location = device)
+        self.net.load_state_dict(checkpoint['model_state_dict'])
+
+        if(self.inferer is None):
+            infer_cfg = self.config['testing']
+            class_num = self.config['network']['class_num']
+            infer_cfg['class_num'] = class_num
+            self.inferer = Inferer(infer_cfg)
+        pred_list = []
+        gt_list = []
+        filename_list = []
+        with torch.no_grad():
+            for data in self.test_loader:
+                images = self.convert_tensor_type(data['image'])
+                labels = self.convert_tensor_type(data['label_prob'])
+                names = data['names']
+                filename_list.append(names)
+                images = images.to(device)
+
+                pred = self.inferer.run(self.net, images)
+                # convert tensor to numpy
+                if(isinstance(pred, (tuple, list))):
+                    pred = [item.cpu().numpy() for item in pred]
+                else:
+                    pred = pred.cpu().numpy()
+                data['predict'] = pred
+                # inverse transform
+                for transform in self.transform_list[::-1]:
+                    if (transform.inverse):
+                        data = transform.inverse_transform_for_prediction(data)
+
+                pred = data['predict']
+                # convert prediction from N, C, H, W to (N*H*W)*C
+                print(names, pred.shape, labels.shape)
+                pred_2d = np.swapaxes(pred, 1, 2)
+                pred_2d = np.swapaxes(pred_2d, 2, 3)
+                pred_2d = pred_2d.reshape(-1, class_num)
+                lab = labels.cpu().numpy()
+                lab_2d = np.swapaxes(lab, 1, 2)
+                lab_2d = np.swapaxes(lab_2d, 2, 3)
+                lab_2d = lab_2d.reshape(-1, class_num)
+                pred_list.append(pred_2d)
+                gt_list.append(lab_2d)
+
+        pred_cat = np.concatenate(pred_list)
+        gt_cat = np.concatenate(gt_list)
+        gt = np.argmax(gt_cat, axis = 1)
+        gt = gt.reshape(-1).astype(np.uint8)
+        print(gt.shape, pred_cat.shape)
+        conf = get_confident_map(gt, pred_cat)
+        conf = conf.reshape(-1, 256, 256).astype(np.uint8) * 255
+        save_dir = self.config['dataset']['root_dir'] + "/slsr_conf"
+        for idx in range(len(filename_list)):
+            filename = filename_list[idx][0].split('/')[-1]
+            conf_map = Image.fromarray(conf[idx])
+            dst_path = os.path.join(save_dir, filename)
+            conf_map.save(dst_path)
+
+def get_confidence_map():
+    if(len(sys.argv) < 2):
+        print('Number of arguments should be 2. e.g.')
+        print('    python nll_cl.py config.cfg')
+        exit()
+    cfg_file = str(sys.argv[1])
+    config = parse_config(cfg_file)
+    config = synchronize_config(config)
+
+    # set dataset
+    transform_names = config['dataset']['valid_transform']
+    transform_list = []
+    transform_dict = TransformDict
+    if(transform_names is None or len(transform_names) == 0):
+        data_transform = None
+    else:
+        transform_param = config['dataset']
+        transform_param['task'] = 'segmentation'
+        for name in transform_names:
+            if(name not in transform_dict):
+                raise(ValueError("Undefined transform {0:}".format(name)))
+            one_transform = transform_dict[name](transform_param)
+            transform_list.append(one_transform)
+        data_transform = transforms.Compose(transform_list)
+    print('transform list', transform_list)
+
+    csv_file = config['dataset']['train_csv']
+    modal_num = config['dataset'].get('modal_num', 1)
+    dataset = NiftyDataset(root_dir = config['dataset']['root_dir'],
+                           csv_file = csv_file,
+                           modal_num = modal_num,
+                           with_label= True,
+                           transform = data_transform )
+
+    agent = NLLCLSLSR(config, 'test')
+    agent.set_datasets(None, None, dataset)
+    agent.transform_list = transform_list
+    agent.create_dataset()
+    agent.create_network()
+    agent.infer_with_cl()
+
+    # create training csv for confidence learning
+    df_train = pd.read_csv(csv_file)
+    pixel_weight = []
+    for i in range(len(df_train["label"])):
+        lab_name = df_train["label"][i].split('/')[-1]
+        weight_name = "slsr_conf/" + lab_name
+        pixel_weight.append(weight_name)
+    train_cl_dict = {"image": df_train["image"],
+                     "pixel_weight": pixel_weight,
+                     "label": df_train["label"]}
+    train_cl_csv = csv_file.replace(".csv", "_clslsr.csv")
+    df_cl = pd.DataFrame.from_dict(train_cl_dict)
+    df_cl.to_csv(train_cl_csv, index = False)
+
+if __name__ == "__main__":
+    get_confidence_map()
\ No newline at end of file
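For reference, `get_confident_map` can be exercised on synthetic flattened arrays. The sketch below assumes a cleanlab 1.x installation (the `cleanlab.pruning.get_noise_indices` function used above was renamed in cleanlab 2.x), and the random arrays are stand-ins for real labels and logits:

```python
import numpy as np
from pymic.net_run_nll.nll_clslsr import get_confident_map

rng = np.random.default_rng(0)
n, c = 1000, 2
gt = rng.integers(0, c, size=n).astype(np.uint8)   # flattened hard labels, shape (N,)
pred = rng.normal(size=(n, c)).astype(np.float32)  # network outputs (logits), shape (N, C)

noise = get_confident_map(gt, pred)  # boolean mask of suspected label errors, shape (N,)
print(noise.shape, noise.sum())
```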
diff --git a/pymic/net_run_ssl/ssl_em.py b/pymic/net_run_ssl/ssl_em.py
index 810a90c..49dd22f 100644
--- a/pymic/net_run_ssl/ssl_em.py
+++ b/pymic/net_run_ssl/ssl_em.py
@@ -3,7 +3,7 @@
 import logging
 import numpy as np
 import torch
-from torhc.optim import lr_scheduler
+from torch.optim import lr_scheduler
 from pymic.loss.seg.util import get_soft_label
 from pymic.loss.seg.util import reshape_prediction_and_ground_truth
 from pymic.loss.seg.util import get_classwise_dice
diff --git a/pymic/net_run_ssl/ssl_urpc.py b/pymic/net_run_ssl/ssl_urpc.py
index d0179cd..20b3d84 100644
--- a/pymic/net_run_ssl/ssl_urpc.py
+++ b/pymic/net_run_ssl/ssl_urpc.py
@@ -4,7 +4,7 @@
 import torch
 import torch.nn as nn
 import numpy as np
-from torhc.optim import lr_scheduler
+from torch.optim import lr_scheduler
 from pymic.loss.seg.util import get_soft_label
 from pymic.loss.seg.util import reshape_prediction_and_ground_truth
 from pymic.loss.seg.util import get_classwise_dice

From 303e624b80e55ddd6d478a67311f299e32b6e8a3 Mon Sep 17 00:00:00 2001
From: taigw
Date: Fri, 19 Aug 2022 16:45:45 +0800
Subject: [PATCH 22/26] add dast for nll

add dast for nll
set the output mode of dual branch network
---
 pymic/loss/seg/ce.py                  |   2 +-
 pymic/net/net2d/unet2d_dual_branch.py |  51 +----
 pymic/net_run_nll/nll_dast.py         | 260 ++++++++++++++++++++++++++
 pymic/net_run_nll/nll_main.py         |   4 +-
 4 files changed, 271 insertions(+), 46 deletions(-)
 create mode 100644 pymic/net_run_nll/nll_dast.py

diff --git a/pymic/loss/seg/ce.py b/pymic/loss/seg/ce.py
index da2bf14..cdef1a0 100644
--- a/pymic/loss/seg/ce.py
+++ b/pymic/loss/seg/ce.py
@@ -6,7 +6,7 @@
 from pymic.loss.seg.util import reshape_tensor_to_2D

 class CrossEntropyLoss(nn.Module):
-    def __init__(self, params):
+    def __init__(self, params = None):
         super(CrossEntropyLoss, self).__init__()
         if(params is None):
             self.softmax = True
diff --git a/pymic/net/net2d/unet2d_dual_branch.py b/pymic/net/net2d/unet2d_dual_branch.py
index 9622bd0..3531c89 100644
--- a/pymic/net/net2d/unet2d_dual_branch.py
+++ b/pymic/net/net2d/unet2d_dual_branch.py
@@ -16,6 +16,7 @@ class UNet2D_DualBranch(nn.Module):
     def __init__(self, params):
         super(UNet2D_DualBranch, self).__init__()
+        self.output_mode = params.get("output_mode", "average")
         self.encoder  = Encoder(params)
         self.decoder1 = Decoder(params)
         self.decoder2 = Decoder(params)
@@ -41,47 +42,9 @@ def forward(self, x):
         if(self.training):
             return output1, output2
         else:
-            return (output1 + output2)/2
-
-# for backup
-class DualBranchUNet2D(UNet2D):
-    def __init__(self, params):
-        params['deep_supervise'] = False
-        super(DualBranchUNet2D, self).__init__(params)
-        if(len(self.ft_chns) == 5):
-            self.up1_aux = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], 0.0, self.bilinear)
-            self.up2_aux = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], 0.0, self.bilinear)
-            self.up3_aux = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], 0.0, self.bilinear)
-            self.up4_aux = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], 0.0, self.bilinear)
-
-        self.out_conv_aux = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size = 1)
-
-    def forward(self, x):
-        x_shape = list(x.shape)
-        if(len(x_shape) == 5):
-            [N, C, D, H, W] = x_shape
-            new_shape = [N*D, C, H, W]
-            x = torch.transpose(x, 1, 2)
-            x = torch.reshape(x, new_shape)
-
-        x0 = self.in_conv(x)
-        x1 = self.down1(x0)
-        x2 = self.down2(x1)
-        x3 = self.down3(x2)
-        if(len(self.ft_chns) == 5):
-            x4 = self.down4(x3)
-            x_d3, x_d3_aux = self.up1(x4, x3), self.up1_aux(x4, x3)
-        else:
-            x_d3, x_d3_aux = x3, x3
-
-        x_d2, x_d2_aux = self.up2(x_d3, x2), self.up2_aux(x_d3_aux, x2)
-        x_d1, x_d1_aux = self.up3(x_d2, x1), self.up3_aux(x_d2_aux, x1)
-        x_d0, x_d0_aux = self.up4(x_d1, x0), self.up4_aux(x_d1_aux, x0)
-        output, output_aux = self.out_conv(x_d0), self.out_conv_aux(x_d0_aux)
-
-        if(len(x_shape) == 5):
-            new_shape = [N, D] + list(output.shape)[1:]
-            output = torch.reshape(output, new_shape)
-            output = torch.transpose(output, 1, 2)
-            output_aux = torch.reshape(output_aux, new_shape)
-            output_aux = torch.transpose(output_aux, 1, 2)
-        return output, output_aux
\ No newline at end of file
+            if(self.output_mode == "average"):
+                return (output1 + output2)/2
+            elif(self.output_mode == "first"):
+                return output1
+            else:
+                return output2
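The new `output_mode` parameter only changes inference: during training the network still returns both branch outputs. A toy sketch of the eval-time selection (the tensors stand in for decoder outputs):

```python
import torch

output1 = torch.tensor([[0.9, 0.1]])  # decoder 1 (used as the clean branch by DAST below)
output2 = torch.tensor([[0.7, 0.3]])  # decoder 2

for output_mode in ["average", "first", "second"]:
    if output_mode == "average":
        out = (output1 + output2) / 2
    elif output_mode == "first":
        out = output1
    else:  # any other value falls through to the second branch
        out = output2
    print(output_mode, out.tolist())
```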
diff --git a/pymic/net_run_nll/nll_dast.py b/pymic/net_run_nll/nll_dast.py
new file mode 100644
index 0000000..d95eec0
--- /dev/null
+++ b/pymic/net_run_nll/nll_dast.py
@@ -0,0 +1,260 @@
+# -*- coding: utf-8 -*-
+"""
+Implementation of DAST for noise robust learning according to the following paper.
+    Shuojue Yang, Guotai Wang, Hui Sun, Xiangde Luo, Peng Sun, Kang Li, Qijun Wang,
+    Shaoting Zhang: Learning COVID-19 Pneumonia Lesion Segmentation from Imperfect
+    Annotations via Divergence-Aware Selective Training.
+    JBHI 2022. https://ieeexplore.ieee.org/document/9770406
+"""
+
+from __future__ import print_function, division
+import random
+import torch
+import numpy as np
+import torch.nn as nn
+import torchvision.transforms as transforms
+from torch.optim import lr_scheduler
+from pymic.io.nifty_dataset import NiftyDataset
+from pymic.loss.seg.util import get_soft_label
+from pymic.loss.seg.util import reshape_prediction_and_ground_truth
+from pymic.loss.seg.util import get_classwise_dice
+from pymic.net_run.agent_seg import SegmentationAgent
+from pymic.util.parse_config import *
+from pymic.util.ramps import get_rampup_ratio
+
+class Rank(object):
+    """
+    Dynamically rank the current training sample with specific metrics
+    """
+    def __init__(self, quene_length = 100):
+        self.vals = []
+        self.quene_length = quene_length
+
+    def add_val(self, val):
+        """
+        Update the queue and calculate the rank of the input value.
+
+        Return
+        ---------
+        rank: rank of the input value with a range of (0, self.quene_length)
+        """
+        if len(self.vals) < self.quene_length:
+            self.vals.append(val)
+            rank = -1
+        else:
+            self.vals.pop(0)
+            self.vals.append(val)
+            assert len(self.vals) == self.quene_length
+            idxes = np.argsort(self.vals)
+            rank = np.where(idxes == self.quene_length-1)[0][0]
+        return rank
+
+class ConsistLoss(nn.Module):
+    def __init__(self):
+        super(ConsistLoss, self).__init__()
+
+    def kl_div_map(self, input, label):
+        kl_map = torch.sum(label * (torch.log(label + 1e-16) - torch.log(input + 1e-16)), dim = 1)
+        return kl_map
+
+    def kl_loss(self, input, target, size_average=True):
+        kl_div = self.kl_div_map(input, target)
+        if size_average:
+            return torch.mean(kl_div)
+        else:
+            return kl_div
+
+    def forward(self, input1, input2, size_average = True):
+        kl1 = self.kl_loss(input1, input2.detach(), size_average=size_average)
+        kl2 = self.kl_loss(input2, input1.detach(), size_average=size_average)
+        return (kl1 + kl2) / 2
+
+def get_ce(prob, soft_y, size_avg = True):
+    prob = prob * 0.999 + 5e-4
+    ce = - soft_y * torch.log(prob)
+    ce = torch.sum(ce, dim = 1)  # shape is [N]
+    if(size_avg):
+        ce = torch.mean(ce)
+    return ce
+
+@torch.no_grad()
+def select_criterion(no_noisy_sample, cl_noisy_sample, label):
+    """
+    no_noisy_sample: noisy branch's output probability for noisy sample
+    cl_noisy_sample: clean branch's output probability for noisy sample
+    label: noisy label
+    """
+    l_n = get_ce(no_noisy_sample, label, size_avg = False)
+    l_c = get_ce(cl_noisy_sample, label, size_avg = False)
+    js_distance = ConsistLoss()
+    variance = js_distance(no_noisy_sample, cl_noisy_sample, size_average=False)
+    exp_variance = torch.exp(-16 * variance)
+    loss_n = torch.mean(l_c * exp_variance).item()
+    loss_c = torch.mean(l_n * exp_variance).item()
+    return loss_n, loss_c
+
+class NLLDAST(SegmentationAgent):
+    def __init__(self, config, stage = 'train'):
+        super(NLLDAST, self).__init__(config, stage)
+        self.train_set_noise = None
+        self.train_loader_noise = None
+        self.trainIter_noise = None
+        self.noisy_rank = None
+        self.clean_rank = None
+
+    def get_noisy_dataset_from_config(self):
+        root_dir = self.config['dataset']['root_dir']
+        modal_num = self.config['dataset'].get('modal_num', 1)
+        transform_names = self.config['dataset']['train_transform']
+
+        self.transform_list = []
+        if(transform_names is None or len(transform_names) == 0):
+            data_transform = None
+        else:
+            transform_param = self.config['dataset']
+            transform_param['task'] = 'segmentation'
+            for name in transform_names:
+                if(name not in self.transform_dict):
+                    raise(ValueError("Undefined transform {0:}".format(name)))
+                one_transform = self.transform_dict[name](transform_param)
+                self.transform_list.append(one_transform)
+            data_transform = transforms.Compose(self.transform_list)
+
+        csv_file = self.config['dataset'].get('train_csv_noise', None)
+        dataset = NiftyDataset(root_dir=root_dir,
+                               csv_file = csv_file,
+                               modal_num = modal_num,
+                               with_label= True,
+                               transform = data_transform )
+        return dataset
+
+    def create_dataset(self):
+        super(NLLDAST, self).create_dataset()
+        if(self.stage == 'train'):
+            if(self.train_set_noise is None):
+                self.train_set_noise = self.get_noisy_dataset_from_config()
+            if(self.deterministic):
+                def worker_init_fn(worker_id):
+                    random.seed(self.random_seed + worker_id)
+                worker_init = worker_init_fn
+            else:
+                worker_init = None
+
+            bn_train_noise = self.config['dataset']['train_batch_size_noise']
+            num_worker = self.config['dataset'].get('num_workder', 16)
+            self.train_loader_noise = torch.utils.data.DataLoader(self.train_set_noise,
+                batch_size = bn_train_noise, shuffle=True, num_workers= num_worker,
+                worker_init_fn=worker_init)
+
+    def training(self):
+        class_num = self.config['network']['class_num']
+        iter_valid = self.config['training']['iter_valid']
+        nll_cfg = self.config['noisy_label_learning']
+        iter_max = self.config['training']['iter_max']
+        rampup_start = nll_cfg.get('rampup_start', 0)
+        rampup_end = nll_cfg.get('rampup_end', iter_max)
+        train_loss = 0
+        train_loss_sup = 0
+        train_loss_reg = 0
+        train_dice_list = []
+        self.net.train()
+
+        rank_length = nll_cfg.get("dast_rank_length", 20)
+        consist_loss = ConsistLoss()
+        for it in range(iter_valid):
+            try:
+                data_cl = next(self.trainIter)
+            except StopIteration:
+                self.trainIter = iter(self.train_loader)
+                data_cl = next(self.trainIter)
+            try:
+                data_no = next(self.trainIter_noise)
+            except StopIteration:
+                self.trainIter_noise = iter(self.train_loader_noise)
+                data_no = next(self.trainIter_noise)
+
+            # get the inputs
+            x0 = self.convert_tensor_type(data_cl['image'])  # clean sample
+            y0 = self.convert_tensor_type(data_cl['label_prob'])
+            x1 = self.convert_tensor_type(data_no['image'])  # noisy sample
+            y1 = self.convert_tensor_type(data_no['label_prob'])
+            inputs = torch.cat([x0, x1], dim = 0).to(self.device)
+            y0, y1 = y0.to(self.device), y1.to(self.device)
+
+            # zero the parameter gradients
+            self.optimizer.zero_grad()
+
+            # forward + backward + optimize
+            b0_pred, b1_pred = self.net(inputs)
+            n0 = list(x0.shape)[0]     # number of clean samples
+            b0_x0_pred = b0_pred[:n0]  # prediction of clean samples from clean branch
+            b0_x1_pred = b0_pred[n0:]  # prediction of noisy samples from clean branch
+            b1_x1_pred = b1_pred[n0:]  # prediction of noisy samples from noisy branch
+
+            # supervised loss for the clean and noisy branches, respectively
+            loss_sup_cl = self.get_loss_value(data_cl, b0_x0_pred, y0)
+            loss_sup_no = self.get_loss_value(data_no, b1_x1_pred, y1)
+            loss_sup = (loss_sup_cl + loss_sup_no) / 2
+            loss = loss_sup
+
+            # Severe Noise supression & Supplementary Training
+            rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid")
+            w_dbc = nll_cfg.get('dast_dbc_w', 0.1) * rampup_ratio
+            w_st = nll_cfg.get('dast_st_w', 0.1) * rampup_ratio
+            b1_x1_prob = nn.Softmax(dim = 1)(b1_x1_pred)
+            b0_x1_prob = nn.Softmax(dim = 1)(b0_x1_pred)
+            loss_n, loss_c = select_criterion(b1_x1_prob, b0_x1_prob, y1)
+            rank_n = self.noisy_rank.add_val(loss_n)
+            rank_c = self.clean_rank.add_val(loss_c)
+            if loss_n < loss_c:
+                if rank_c >= rank_length * 0.8:
+                    loss_dbc = consist_loss(b1_x1_prob, b0_x1_prob)
+                    loss = loss + loss_dbc * w_dbc
+                if rank_n <= 0.2 * rank_length:
+                    b0_x1_argmax = torch.argmax(b0_x1_pred, dim = 1, keepdim = True)
+                    b0_x1_lab = get_soft_label(b0_x1_argmax, class_num, self.tensor_type)
+                    b1_x1_argmax = torch.argmax(b1_x1_pred, dim = 1, keepdim = True)
+                    b1_x1_lab = get_soft_label(b1_x1_argmax, class_num, self.tensor_type)
+                    pseudo_label = (b0_x1_lab + b1_x1_lab + y1) / 3
+                    sharpen = lambda p,T: p**(1.0/T)/(p**(1.0/T) + (1-p)**(1.0/T))
+                    b0_x1_prob = nn.Softmax(dim = 1)(b0_x1_pred)
+                    loss_st = torch.mean(torch.abs(b0_x1_prob - sharpen(pseudo_label, 0.5)))
+                    loss = loss + loss_st * w_st
+
+            loss.backward()
+            self.optimizer.step()
+            if(self.scheduler is not None and \
+                not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
+                self.scheduler.step()
+
+            train_loss = train_loss + loss.item()
+            train_loss_sup = train_loss_sup + loss_sup.item()
+            # train_loss_reg = train_loss_reg + loss_reg.item()
+            # get dice evaluation for each class in annotated images
+            if(isinstance(b0_x0_pred, tuple) or isinstance(b0_x0_pred, list)):
+                p0 = b0_x0_pred[0]
+            else:
+                p0 = b0_x0_pred
+            p0_argmax = torch.argmax(p0, dim = 1, keepdim = True)
+            p0_soft = get_soft_label(p0_argmax, class_num, self.tensor_type)
+            p0_soft, y0 = reshape_prediction_and_ground_truth(p0_soft, y0)
+            dice_list = get_classwise_dice(p0_soft, y0)
+            train_dice_list.append(dice_list.cpu().numpy())
+        train_avg_loss = train_loss / iter_valid
+        train_avg_loss_sup = train_loss_sup / iter_valid
+        train_avg_loss_reg = train_loss_reg / iter_valid
+        train_cls_dice = np.asarray(train_dice_list).mean(axis = 0)
+        train_avg_dice = train_cls_dice.mean()
+
+        train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup,
+            'loss_reg':train_avg_loss_reg, 'regular_w':w_dbc,
+            'avg_dice':train_avg_dice, 'class_dice': train_cls_dice}
+        return train_scalers
+
+    def train_valid(self):
+        self.trainIter_noise = iter(self.train_loader_noise)
+        nll_cfg = self.config['noisy_label_learning']
+        rank_length = nll_cfg.get("dast_rank_length", 20)
+        self.noisy_rank = Rank(rank_length)
+        self.clean_rank = Rank(rank_length)
+        super(NLLDAST, self).train_valid()
\ No newline at end of file
diff --git a/pymic/net_run_nll/nll_main.py b/pymic/net_run_nll/nll_main.py
index 7d8f1f8..cc07a44 100644
--- a/pymic/net_run_nll/nll_main.py
+++ b/pymic/net_run_nll/nll_main.py
@@ -7,9 +7,11 @@
 from pymic.util.parse_config import *
 from pymic.net_run_nll.nll_co_teaching import NLLCoTeaching
 from pymic.net_run_nll.nll_trinet import NLLTriNet
+from pymic.net_run_nll.nll_dast import NLLDAST

 NLLMethodDict = {'CoTeaching': NLLCoTeaching,
-                 "TriNet": NLLTriNet}
+                 "TriNet": NLLTriNet,
+                 "DAST": NLLDAST}

 def main():
     if(len(sys.argv) < 3):
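The `Rank` queue defined in `nll_dast.py` returns -1 until it has collected `quene_length` values; after that, each call reports where the newest value falls within the sorted sliding window. A small sketch (assuming the module above is importable):

```python
from pymic.net_run_nll.nll_dast import Rank

rank = Rank(quene_length = 5)
for v in [0.9, 0.5, 0.7, 0.3, 0.8]:
    print(rank.add_val(v))  # -1 five times: the queue is still filling
print(rank.add_val(0.2))    # 0: the newest value is the smallest in the window
print(rank.add_val(1.0))    # 4: the newest value is the largest in the window
```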
From 24c46cffc58d75d3b7f3cee8552f0d045c7d1398 Mon Sep 17 00:00:00 2001
From: taigw
Date: Fri, 19 Aug 2022 16:50:19 +0800
Subject: [PATCH 23/26] Update nll_dast.py

add config parameter
---
 pymic/net_run_nll/nll_dast.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/pymic/net_run_nll/nll_dast.py b/pymic/net_run_nll/nll_dast.py
index d95eec0..19a59a2 100644
--- a/pymic/net_run_nll/nll_dast.py
+++ b/pymic/net_run_nll/nll_dast.py
@@ -207,10 +207,11 @@ def training(self):
             rank_n = self.noisy_rank.add_val(loss_n)
             rank_c = self.clean_rank.add_val(loss_c)
             if loss_n < loss_c:
-                if rank_c >= rank_length * 0.8:
+                select_ratio = nll_cfg.get('dast_select_ratio', 0.2)
+                if rank_c >= rank_length * (1 - select_ratio):
                     loss_dbc = consist_loss(b1_x1_prob, b0_x1_prob)
                     loss = loss + loss_dbc * w_dbc
-                if rank_n <= 0.2 * rank_length:
+                if rank_n <= rank_length * select_ratio:
                     b0_x1_argmax = torch.argmax(b0_x1_pred, dim = 1, keepdim = True)
                     b0_x1_lab = get_soft_label(b0_x1_argmax, class_num, self.tensor_type)
                     b1_x1_argmax = torch.argmax(b1_x1_pred, dim = 1, keepdim = True)
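The supplementary-training branch gated by `dast_select_ratio` relies on the temperature-sharpening function defined inline in `NLLDAST.training()`; with `T = 0.5` it maps a soft pseudo-label `p` to `p**2 / (p**2 + (1-p)**2)`, pushing averaged labels back toward 0 or 1. A worked example:

```python
import torch

def sharpen(p, T):
    # same formula as the inline lambda in NLLDAST.training()
    return p**(1.0/T) / (p**(1.0/T) + (1-p)**(1.0/T))

# averaging two branch predictions with the noisy label gives soft values such as 1/3 or 2/3
p = torch.tensor([1/3, 0.5, 2/3, 1.0])
print(sharpen(p, 0.5))  # tensor([0.2000, 0.5000, 0.8000, 1.0000])
```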
From 75622c346f7e753c668ef734f0be5085d6a6074a Mon Sep 17 00:00:00 2001
From: taigw
Date: Sat, 20 Aug 2022 12:01:23 +0800
Subject: [PATCH 24/26] update comment and fix typo

update comment and fix typo
---
 pymic/net_run/agent_cls.py      | 14 ++++++++------
 pymic/net_run_nll/nll_clslsr.py |  3 ++-
 pymic/net_run_nll/nll_trinet.py | 17 +++++------------
 3 files changed, 15 insertions(+), 19 deletions(-)

diff --git a/pymic/net_run/agent_cls.py b/pymic/net_run/agent_cls.py
index 71d7c30..8687048 100644
--- a/pymic/net_run/agent_cls.py
+++ b/pymic/net_run/agent_cls.py
@@ -195,7 +195,9 @@ def train_valid(self):
         ckpt_dir = self.config['training']['ckpt_save_dir']
         if(ckpt_dir[-1] == "/"):
             ckpt_dir = ckpt_dir[:-1]
-        ckpt_prefx = ckpt_dir.split('/')[-1]
+        ckpt_prefix = self.config['training'].get('ckpt_prefix', None)
+        if(ckpt_prefix is None):
+            ckpt_prefix = ckpt_dir.split('/')[-1]
         iter_start = self.config['training']['iter_start']
         iter_max = self.config['training']['iter_max']
         iter_valid = self.config['training']['iter_valid']
@@ -206,7 +208,7 @@ def train_valid(self):
         self.best_model_wts = None
         self.checkpoint = None
         if(iter_start > 0):
-            checkpoint_file = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefx, iter_start)
+            checkpoint_file = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, iter_start)
             self.checkpoint = torch.load(checkpoint_file, map_location = self.device)
             assert(self.checkpoint['iteration'] == iter_start)
             self.net.load_state_dict(self.checkpoint['model_state_dict'])
@@ -237,9 +239,9 @@ def train_valid(self):
                     'valid_pred': valid_scalars[metrics],
                     'model_state_dict': self.net.state_dict(),
                     'optimizer_state_dict': self.optimizer.state_dict()}
-                save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefx, glob_it)
+                save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, glob_it)
                 torch.save(save_dict, save_name)
-                txt_file = open("{0:}/{1:}_latest.txt".format(ckpt_dir, ckpt_prefx), 'wt')
+                txt_file = open("{0:}/{1:}_latest.txt".format(ckpt_dir, ckpt_prefix), 'wt')
                 txt_file.write(str(glob_it))
                 txt_file.close()
@@ -248,9 +250,9 @@ def train_valid(self):
             'valid_pred': self.max_val_score,
             'model_state_dict': self.best_model_wts,
             'optimizer_state_dict': self.optimizer.state_dict()}
-        save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefx, self.max_val_it)
+        save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.max_val_it)
         torch.save(save_dict, save_name)
-        txt_file = open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefx), 'wt')
+        txt_file = open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefix), 'wt')
         txt_file.write(str(self.max_val_it))
         txt_file.close()
         logging.info('The best perfroming iter is {0:}, valid {1:} {2:}'.format(\
             self.max_val_it, metrics, self.max_val_score))
diff --git a/pymic/net_run_nll/nll_clslsr.py b/pymic/net_run_nll/nll_clslsr.py
index 2894db6..9ee7182 100644
--- a/pymic/net_run_nll/nll_clslsr.py
+++ b/pymic/net_run_nll/nll_clslsr.py
@@ -3,7 +3,8 @@
 Calculating the confidence map of labels of training samples,
 which is used in the method of SLSR.
     Minqing Zhang et al., Characterizing Label Errors: Confident Learning
-    for Noisy-Labeled Image Segmentation, MICCAI 2020.
+    for Noisy-Labeled Image Segmentation, MICCAI 2020.
+    https://link.springer.com/chapter/10.1007/978-3-030-59710-8_70
 """
diff --git a/pymic/net_run_nll/nll_trinet.py b/pymic/net_run_nll/nll_trinet.py
index eb0ecdd..6af5449 100644
--- a/pymic/net_run_nll/nll_trinet.py
+++ b/pymic/net_run_nll/nll_trinet.py
@@ -1,13 +1,11 @@
 # -*- coding: utf-8 -*-
 """
-Implementation of Co-teaching for learning from noisy samples for
+Implementation of TriNet for learning from noisy samples for
 segmentation tasks according to the following paper:
-    Bo Han et al., Co-teaching: Robust Training of Deep NeuralNetworks
-    with Extremely Noisy Labels, NeurIPS, 2018
-The author's original implementation was:
-https://github.com/bhanML/Co-teaching
-
-
+    Tianwei Zhang, Lequan Yu, Na Hu, Su Lv, Shi Gu:
+    Robust Medical Image Segmentation from Non-expert Annotations with Tri-network.
+    MICCAI 2020.
+    https://link.springer.com/chapter/10.1007/978-3-030-59719-1_25
 """
 from __future__ import print_function, division
 import logging
@@ -48,11 +46,6 @@ def forward(self, x):
         return (out1 + out2 + out3) / 3

 class NLLTriNet(SegmentationAgent):
-    """
-    Co-teaching: Robust Training of Deep Neural Networks with Extremely
-    Noisy Labels
-    https://arxiv.org/abs/1804.06872
-    """
     def __init__(self, config, stage = 'train'):
         super(NLLTriNet, self).__init__(config, stage)

From 8f8eb33ff88cb657ca925e5c782a80d0fbb07c3a Mon Sep 17 00:00:00 2001
From: taigw
Date: Sat, 20 Aug 2022 12:07:53 +0800
Subject: [PATCH 25/26] Update README.md

---
 README.md | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/README.md b/README.md
index d6007e8..e29abde 100644
--- a/README.md
+++ b/README.md
@@ -31,24 +31,26 @@ PyMIC provides flexible modules for medical image computing tasks including clas
 [tbx_link]:https://github.com/lanpa/tensorboardX

 ## Installation
-Run the following command to install the current released version of PyMIC:
+Run the following command to install the latest released version of PyMIC:

 ```bash
 pip install PYMIC
 ```

-Alternatively, you can download the source code for the latest version. Run the following command to compile and install:
+Alternatively, you can download the source code for the latest dev version. Run the following command to compile and install:

 ```bash
 python setup.py install
 ```

-## Examples
-[PyMIC_examples][examples] provides some examples of starting to use PyMIC. At the beginning, you only need to edit the configuration files to select different datasets, networks and training methods for running the code. When you are more familiar with PyMIC, you can customize different modules in the PyMIC package. You can find both types of examples:
+## How to start
+* [PyMIC_examples][exp_link] shows some examples of starting to use PyMIC.
+* [PyMIC_doc][docs_link] provides documentation of this project.

-[examples]: https://github.com/HiLab-git/PyMIC_examples
+[docs_link]:https://pymic.readthedocs.io/en/latest/
+[exp_link]:https://github.com/HiLab-git/PyMIC_examples

-# Projects based on PyMIC
+## Projects based on PyMIC
 Using PyMIC, it becomes easy to develop deep learning models for different projects, such as the following:

 1, [COPLE-Net][coplenet] (TMI 2020), COVID-19 Pneumonia Segmentation from CT images.

From 51629241f12a83016f79d10efce494a2342b996c Mon Sep 17 00:00:00 2001
From: taigw
Date: Sat, 20 Aug 2022 18:04:05 +0800
Subject: [PATCH 26/26] add init file and update version

add init file and update version
---
 pymic/net_run_nll/__init__.py     |  0
 pymic/net_run_ssl/__init__.py     |  0
 pymic/net_run_wsl/__init__.py     |  0
 pymic/net_run_wsl/wsl_dmpls.py    |  2 +-
 pymic/net_run_wsl/wsl_gatedcrf.py |  2 +-
 requirements.txt                  | 12 ++++++++++++
 setup.py                          |  6 ++++--
 7 files changed, 18 insertions(+), 4 deletions(-)
 create mode 100644 pymic/net_run_nll/__init__.py
 create mode 100644 pymic/net_run_ssl/__init__.py
 create mode 100644 pymic/net_run_wsl/__init__.py
 create mode 100644 requirements.txt

diff --git a/pymic/net_run_nll/__init__.py b/pymic/net_run_nll/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/pymic/net_run_ssl/__init__.py b/pymic/net_run_ssl/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/pymic/net_run_wsl/__init__.py b/pymic/net_run_wsl/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/pymic/net_run_wsl/wsl_dmpls.py b/pymic/net_run_wsl/wsl_dmpls.py
index a198ddc..8ee9e53 100644
--- a/pymic/net_run_wsl/wsl_dmpls.py
+++ b/pymic/net_run_wsl/wsl_dmpls.py
@@ -4,7 +4,7 @@
 import numpy as np
 import random
 import torch
-from torhc.optim import lr_scheduler
+from torch.optim import lr_scheduler
 from pymic.loss.seg.util import get_soft_label
 from pymic.loss.seg.util import reshape_prediction_and_ground_truth
 from pymic.loss.seg.util import get_classwise_dice
diff --git a/pymic/net_run_wsl/wsl_gatedcrf.py b/pymic/net_run_wsl/wsl_gatedcrf.py
index 2be8856..64e0f1b 100644
--- a/pymic/net_run_wsl/wsl_gatedcrf.py
+++ b/pymic/net_run_wsl/wsl_gatedcrf.py
@@ -3,7 +3,7 @@
 import logging
 import numpy as np
 import torch
-from torhc.optim import lr_scheduler
+from torch.optim import lr_scheduler
 from pymic.loss.seg.util import get_soft_label
 from pymic.loss.seg.util import reshape_prediction_and_ground_truth
 from pymic.loss.seg.util import get_classwise_dice
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..2dc1604
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,12 @@
+matplotlib>=3.1.2
+numpy>=1.17.4
+pandas>=0.25.3
+python>=3.6
+scikit-image>=0.16.2
+scikit-learn>=0.22
+scipy>=1.3.3
+SimpleITK>=1.2.4
+tensorboard>=2.1.0
+tensorboardX>=1.9
+torch>=1.7.1
+torchvision>=0.8.2
diff --git a/setup.py b/setup.py
index 498aa26..ce7271b 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@

 # Get the summary
 description = 'An open-source deep learning platform' + \
-              ' for medical image computing'
+              ' for annotation-efficient medical image computing'

 # Get the long description
 with open('README.md', encoding='utf-8') as f:
@@ -11,7 +11,7 @@

 setuptools.setup(
     name    = 'PYMIC',
-    version = "0.2.5",
+    version = "0.3.0",
     author  ='PyMIC Consortium',
     author_email = 'wguotai@gmail.com',
     description  = description,
@@ -31,6 +31,8 @@
         'console_scripts': [
             'pymic_run  = pymic.net_run.net_run:main',
             'pymic_ssl  = pymic.net_run_ssl.ssl_main:main',
+            'pymic_wsl  = pymic.net_run_wsl.wsl_main:main',
+            'pymic_nll  = pymic.net_run_nll.nll_main:main',
             'pymic_eval_cls = pymic.util.evaluation_cls:main',
             'pymic_eval_seg = pymic.util.evaluation_seg:main'
         ],
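With the new console entries, the weakly-supervised and noisy-label pipelines can be launched directly as `pymic_wsl`/`pymic_nll`. A minimal sketch of what the `pymic_wsl` entry point resolves to (the config path is hypothetical):

```python
import sys
from pymic.net_run_wsl import wsl_main

# equivalent to running `pymic_wsl train my_config.cfg` in a shell; main()
# reads the stage and the config file name from sys.argv
sys.argv = ["pymic_wsl", "train", "my_config.cfg"]
wsl_main.main()
```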