diff --git a/README.md b/README.md index ec004c9..3ce5929 100644 --- a/README.md +++ b/README.md @@ -1,43 +1,45 @@ # PyMIC: A Pytorch-Based Toolkit for Medical Image Computing -PyMIC is a pytorch-based toolkit for medical image computing with deep learning. Despite that pytorch is a fantastic platform for deep learning, using it for medical image computing is not straightforward as medical images are often with higher dimension, multiple modalities and low contrast. The toolkit is developed to facilitate medical image computing researchers so that training and testing deep learning models become easier. It is very friendly to researchers who are new to this area. Even without writing any code, you can use PyMIC commands to train and test a model by simply editing configure files. +PyMIC is a pytorch-based toolkit for medical image computing with annotation-efficient deep learning. Although pytorch is a fantastic platform for deep learning, using it for medical image computing is not straightforward: medical images often have high dimensionality, large volumes and multiple modalities, and are difficult to annotate. This toolkit is developed to facilitate medical image computing researchers so that training and testing deep learning models become easier. It is very friendly to researchers who are new to this area. Even without writing any code, you can use PyMIC commands to train and test a model by simply editing configuration files. PyMIC is developed to support learning with imperfect labels, including semi-supervised and weakly supervised learning, and learning with noisy annotations. Currently PyMIC supports 2D/3D medical image classification and segmentation, and it is still under development. It was originally developed for COVID-19 pneumonia lesion segmentation from CT images. If you use this toolkit, please cite the following paper: - * G. Wang, X. Liu, C. Li, Z. Xu, J. Ruan, H. Zhu, T. Meng, K. Li, N. Huang, S. Zhang. [A Noise-robust Framework for Automatic Segmentation of COVID-19 Pneumonia Lesions from CT Images.][tmi2020] IEEE Transactions on Medical Imaging. 39(8):2653-2663, 2020. DOI: [10.1109/TMI.2020.3000314][tmi2020] [tmi2020]:https://ieeexplore.ieee.org/document/9109297 -# Advantages -PyMIC provides some basic modules for medical image computing that can be share by different applications. We currently provide the following functions: +# Features +PyMIC provides flexible modules for medical image computing tasks including classification and segmentation. It currently provides the following functions: +* Support for annotation-efficient image segmentation, especially for semi-supervised, weakly-supervised and noisy-label learning. +* User friendly: For beginners, you only need to edit the configuration files for model training and inference, without writing code. For advanced users, you can customize different modules (networks, loss functions, training pipeline, etc.) and easily integrate them into PyMIC. * Easy-to-use I/O interface to read and write different 2D and 3D images. +* Various data pre-processing/transformation methods before sending a tensor into a network. +* Implementation of typical neural networks for medical image segmentation. * Re-usable training and testing pipeline that can be transferred to different tasks. -* Various data pre-processing methods before sending a tensor into a network. -* Implementation of loss functions, especially for image segmentation.
-* Implementation of evaluation metrics to get quantitative evaluation of your methods (for segmentation). +* Evaluation metrics for quantitative evaluation of your methods. # Usage ## Requirement * [Pytorch][torch_link] version >=1.0.1 * [TensorboardX][tbx_link] to visualize training performance * Some common python packages such as Numpy, Pandas, SimpleITK +* See `requirements.txt` for details. [torch_link]:https://pytorch.org/ [tbx_link]:https://github.com/lanpa/tensorboardX ## Installation -Run the following command to install the current released version of PyMIC: +Run the following command to install the latest released version of PyMIC: ```bash pip install PYMIC ``` -To install a specific version of PYMIC such as 0.2.4, run: +To install a specific version of PYMIC such as 0.3.0, run: ```bash -pip install PYMIC==0.2.4 +pip install PYMIC==0.3.0 ``` Alternatively, you can download the source code for the latest version. Run the following command to compile and install: @@ -45,12 +47,14 @@ Alternatively, you can download the source code for the latest version. Run the python setup.py install ``` -## Examples -[PyMIC_examples][examples] provides some examples of starting to use PyMIC. For beginners, you only need to simply change the configuration files to select different datasets, networks and training methods for running the code. For advanced users, you can develop your own modules based on this package. You can find both types of examples +## How to start +* [PyMIC_examples][exp_link] shows some examples of starting to use PyMIC. +* [PyMIC_doc][docs_link] provides documentation of this project. -[examples]: https://github.com/HiLab-git/PyMIC_examples +[docs_link]:https://pymic.readthedocs.io/en/latest/ +[exp_link]:https://github.com/HiLab-git/PyMIC_examples -# Projects based on PyMIC +## Projects based on PyMIC Using PyMIC, it becomes easy to develop deep learning models for different projects, such as the following: 1, [MyoPS][myops] Winner of the MICCAI 2020 myocardial pathology segmentation (MyoPS) Challenge. diff --git a/pymic/loss/loss_dict_seg.py b/pymic/loss/loss_dict_seg.py index 929ec43..a8a53ad 100644 --- a/pymic/loss/loss_dict_seg.py +++ b/pymic/loss/loss_dict_seg.py @@ -1,14 +1,14 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division import torch.nn as nn -from pymic.loss.seg.ce import CrossEntropyLoss, GeneralizedCrossEntropyLoss +from pymic.loss.seg.ce import CrossEntropyLoss, GeneralizedCELoss from pymic.loss.seg.dice import DiceLoss, FocalDiceLoss, NoiseRobustDiceLoss from pymic.loss.seg.slsr import SLSRLoss from pymic.loss.seg.exp_log import ExpLogLoss from pymic.loss.seg.mse import MSELoss, MAELoss SegLossDict = {'CrossEntropyLoss': CrossEntropyLoss, - 'GeneralizedCrossEntropyLoss': GeneralizedCrossEntropyLoss, + 'GeneralizedCELoss': GeneralizedCELoss, 'SLSRLoss': SLSRLoss, 'DiceLoss': DiceLoss, 'FocalDiceLoss': FocalDiceLoss, diff --git a/pymic/loss/seg/ce.py b/pymic/loss/seg/ce.py index dadeba7..cdef1a0 100644 --- a/pymic/loss/seg/ce.py +++ b/pymic/loss/seg/ce.py @@ -6,7 +6,7 @@ from pymic.loss.seg.util import reshape_tensor_to_2D class CrossEntropyLoss(nn.Module): - def __init__(self, params): + def __init__(self, params = None): super(CrossEntropyLoss, self).__init__() if(params is None): self.softmax = True @@ -59,34 +59,36 @@ def forward(self, loss_input_dict): ce = torch.mean(ce) return ce -class GeneralizedCrossEntropyLoss(nn.Module): +class GeneralizedCELoss(nn.Module): """ Generalized cross entropy loss to deal with noisy labels. Z. 
Zhang et al. Generalized Cross Entropy Loss for Training Deep Neural Networks with Noisy Labels, NeurIPS 2018. """ def __init__(self, params): - super(GeneralizedCrossEntropyLoss, self).__init__() - self.enable_pix_weight = params['GeneralizedCrossEntropyLoss_Enable_Pixel_Weight'.lower()] - self.enable_cls_weight = params['GeneralizedCrossEntropyLoss_Enable_Class_Weight'.lower()] - self.q = params['GeneralizedCrossEntropyLoss_q'.lower()] + """ + q: in (0, 1], becomes MAE when q = 1 + """ + super(GeneralizedCELoss, self).__init__() + self.enable_pix_weight = params.get('GeneralizedCELoss_Enable_Pixel_Weight', False) + self.enable_cls_weight = params.get('GeneralizedCELoss_Enable_Class_Weight', False) + self.q = params.get('GeneralizedCELoss_q', 0.5) + self.softmax = params.get('loss_softmax', True) def forward(self, loss_input_dict): predict = loss_input_dict['prediction'] - soft_y = loss_input_dict['ground_truth'] - pix_w = loss_input_dict['pixel_weight'] - cls_w = loss_input_dict['class_weight'] - softmax = loss_input_dict['softmax'] + soft_y = loss_input_dict['ground_truth'] if(isinstance(predict, (list, tuple))): predict = predict[0] - if(softmax): + if(self.softmax): predict = nn.Softmax(dim = 1)(predict) predict = reshape_tensor_to_2D(predict) soft_y = reshape_tensor_to_2D(soft_y) gce = (1.0 - torch.pow(predict, self.q)) / self.q * soft_y if(self.enable_cls_weight): + cls_w = loss_input_dict.get('class_weight', None) if(cls_w is None): raise ValueError("Class weight is enabled but not defined") gce = torch.sum(gce * cls_w, dim = 1) @@ -94,6 +96,7 @@ def forward(self, loss_input_dict): gce = torch.sum(gce, dim = 1) if(self.enable_pix_weight): + pix_w = loss_input_dict.get('pixel_weight', None) if(pix_w is None): raise ValueError("Pixel weight is enabled but not defined") pix_w = reshape_tensor_to_2D(pix_w) diff --git a/pymic/loss/seg/mumford_shah.py b/pymic/loss/seg/mumford_shah.py index eeaa250..f167b71 100644 --- a/pymic/loss/seg/mumford_shah.py +++ b/pymic/loss/seg/mumford_shah.py @@ -4,37 +4,14 @@ import torch import torch.nn as nn -class DiceLoss(nn.Module): - def __init__(self, params = None): - super(DiceLoss, self).__init__() - if(params is None): - self.softmax = True - else: - self.softmax = params.get('loss_softmax', True) - - def forward(self, loss_input_dict): - predict = loss_input_dict['prediction'] - soft_y = loss_input_dict['ground_truth'] - - if(isinstance(predict, (list, tuple))): - predict = predict[0] - if(self.softmax): - predict = nn.Softmax(dim = 1)(predict) - predict = reshape_tensor_to_2D(predict) - soft_y = reshape_tensor_to_2D(soft_y) - dice_score = get_classwise_dice(predict, soft_y) - dice_loss = 1.0 - dice_score.mean() - return dice_loss - class MumfordShahLoss(nn.Module): """ Implementation of Mumford Shah Loss in this paper: - Boah Kim and Jong Chul Ye, Mumford–Shah Loss Functional + Boah Kim and Jong Chul Ye: Mumford–Shah Loss Functional for Image Segmentation With Deep Learning. IEEE TIP, 2019. The original implementation is available at: https://github.com/jongcye/CNN_MumfordShah_Loss - - currently only 2D version is supported. + Currently only 2D version is supported.
""" def __init__(self, params = None): super(MumfordShahLoss, self).__init__() diff --git a/pymic/loss/seg/slsr.py b/pymic/loss/seg/slsr.py index 6ad60b3..706d2fc 100644 --- a/pymic/loss/seg/slsr.py +++ b/pymic/loss/seg/slsr.py @@ -2,8 +2,10 @@ """ Spatial Label Smoothing Regularization (SLSR) loss for learning from noisy annotatins according to the following paper: - Minqing Zhang, Jiantao Gao et al., Characterizing Label Errors: - Confident Learning for Noisy-Labeled Image Segmentation, MICCAI 2020. + Minqing Zhang, Jiantao Gao et al.: + Characterizing Label Errors: Confident Learning for Noisy-Labeled Image + Segmentation, MICCAI 2020. + https://link.springer.com/chapter/10.1007/978-3-030-59710-8_70 """ from __future__ import print_function, division @@ -17,7 +19,7 @@ def __init__(self, params): if(params is None): params = {} self.softmax = params.get('loss_softmax', True) - self.epsilon = params.get('slsrloss_softmax', 0.25) + self.epsilon = params.get('slsrloss_epsilon', 0.25) def forward(self, loss_input_dict): predict = loss_input_dict['prediction'] @@ -35,7 +37,6 @@ def forward(self, loss_input_dict): soft_y = reshape_tensor_to_2D(soft_y) if(pix_w is not None): pix_w = reshape_tensor_to_2D(pix_w > 0).float() - # smooth labels for pixels in the unconfident mask smooth_y = (soft_y - 0.5) * (0.5 - self.epsilon) / 0.5 + 0.5 smooth_y = pix_w * smooth_y + (1 - pix_w) * soft_y diff --git a/pymic/net/net2d/unet2d.py b/pymic/net/net2d/unet2d.py index 0cc607f..a361f0f 100644 --- a/pymic/net/net2d/unet2d.py +++ b/pymic/net/net2d/unet2d.py @@ -72,6 +72,67 @@ def forward(self, x1, x2): x = torch.cat([x2, x1], dim=1) return self.conv(x) +class Encoder(nn.Module): + def __init__(self, params): + super(Encoder, self).__init__() + self.params = params + self.in_chns = self.params['in_chns'] + self.ft_chns = self.params['feature_chns'] + self.dropout = self.params['dropout'] + assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) + + self.in_conv= ConvBlock(self.in_chns, self.ft_chns[0], self.dropout[0]) + self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) + self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) + self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) + if(len(self.ft_chns) == 5): + self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) + + def forward(self, x): + x0 = self.in_conv(x) + x1 = self.down1(x0) + x2 = self.down2(x1) + x3 = self.down3(x2) + output = [x0, x1, x2, x3] + if(len(self.ft_chns) == 5): + x4 = self.down4(x3) + output.append(x4) + return output + +class Decoder(nn.Module): + def __init__(self, params): + super(Decoder, self).__init__() + self.params = params + self.in_chns = self.params['in_chns'] + self.ft_chns = self.params['feature_chns'] + self.dropout = self.params['dropout'] + self.n_class = self.params['class_num'] + self.bilinear = self.params['bilinear'] + + assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) + + if(len(self.ft_chns) == 5): + self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.bilinear) + self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.bilinear) + self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.bilinear) + self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.bilinear) + self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size = 1) + + def forward(self, x): + 
if(len(self.ft_chns) == 5): + assert(len(x) == 5) + x0, x1, x2, x3, x4 = x + x_d3 = self.up1(x4, x3) + else: + assert(len(x) == 4) + x0, x1, x2, x3 = x + x_d3 = x3 + x_d2 = self.up2(x_d3, x2) + x_d1 = self.up3(x_d2, x1) + x_d0 = self.up4(x_d1, x0) + output = self.out_conv(x_d0) + return output + class UNet2D(nn.Module): def __init__(self, params): super(UNet2D, self).__init__() @@ -91,10 +152,10 @@ def __init__(self, params): self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) if(len(self.ft_chns) == 5): self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], 0.0, self.bilinear) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], 0.0, self.bilinear) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], 0.0, self.bilinear) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], 0.0, self.bilinear) + self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], self.dropout[3], self.bilinear) + self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], self.dropout[2], self.bilinear) + self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], self.dropout[1], self.bilinear) + self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], self.dropout[0], self.bilinear) self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size = 1) if(self.deep_sup): diff --git a/pymic/net/net2d/unet2d_cct.py b/pymic/net/net2d/unet2d_cct.py new file mode 100644 index 0000000..f7558bc --- /dev/null +++ b/pymic/net/net2d/unet2d_cct.py @@ -0,0 +1,134 @@ +# -*- coding: utf-8 -*- +""" +A modification of the U-Net with auxiliary decoders according to +the CCT paper: + Yassine Ouali, Celine Hudelot and Myriam Tami: + Semi-Supervised Semantic Segmentation With Cross-Consistency Training. + CVPR 2020.
+ https://arxiv.org/abs/2003.09005 +Code adapted from: https://github.com/yassouali/CCT +""" +from __future__ import print_function, division + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from torch.distributions.uniform import Uniform +from pymic.net.net2d.unet2d import Encoder, Decoder + +def _l2_normalize(d): + # Normalizing per batch axis + d_reshaped = d.view(d.shape[0], -1, *(1 for _ in range(d.dim() - 2))) + d /= torch.norm(d_reshaped, dim=1, keepdim=True) + 1e-8 + return d + + + +def get_r_adv(x_list, decoder, it=1, xi=1e-1, eps=10.0): + """ + Virtual Adversarial Training according to + https://arxiv.org/abs/1704.03976 + """ + x_detached = [item.detach() for item in x_list] + xe_detached = x_detached[-1] + with torch.no_grad(): + pred = F.softmax(decoder(x_detached), dim=1) + + d = torch.rand(x_list[-1].shape).sub(0.5).to(x_list[-1].device) + d = _l2_normalize(d) + + for _ in range(it): + d.requires_grad_() + x_detached[-1] = xe_detached + xi * d + pred_hat = decoder(x_detached) + logp_hat = F.log_softmax(pred_hat, dim=1) + adv_distance = F.kl_div(logp_hat, pred, reduction='batchmean') + adv_distance.backward() + d = _l2_normalize(d.grad) + decoder.zero_grad() + + r_adv = d * eps + return x_list[-1] + r_adv + + +class AuxiliaryDecoder(nn.Module): + def __init__(self, params, aux_type): + super(AuxiliaryDecoder, self).__init__() + self.params = params + self.decoder = Decoder(params) + self.aux_type = aux_type + uniform_range = params.get("Uniform_range".lower(), 0.3) + self.uni_dist = Uniform(-uniform_range, uniform_range) + + def feature_drop(self, x): + attention = torch.mean(x, dim=1, keepdim=True) + max_val, _ = torch.max(attention.view(x.size(0), -1), dim=1, keepdim=True) + threshold = max_val * np.random.uniform(0.7, 0.9) + threshold = threshold.view(x.size(0), 1, 1, 1).expand_as(attention) + drop_mask = (attention < threshold).float() + return x.mul(drop_mask) + + def feature_based_noise(self, x): + noise_vector = self.uni_dist.sample(x.shape[1:]).to(x.device).unsqueeze(0) + x_noise = x.mul(noise_vector) + x + return x_noise + + def forward(self, x): + if(self.aux_type == "DropOut"): + pass + elif(self.aux_type == "FeatureDrop"): + x[-1] = self.feature_drop(x[-1]) + elif(self.aux_type == "FeatureNoise"): + x[-1] = self.feature_based_noise(x[-1]) + elif(self.aux_type == "VAT"): + it = self.params.get("VAT_it".lower(), 2) + xi = self.params.get("VAT_xi".lower(), 1e-6) + eps= self.params.get("VAT_eps".lower(), 2.0) + x[-1] = get_r_adv(x, self.decoder, it, xi, eps) + else: + raise ValueError("Undefined auxiliary decoder type {0:}".format(self.aux_type)) + + output = self.decoder(x) + return output + + +class UNet2D_CCT(nn.Module): + def __init__(self, params): + super(UNet2D_CCT, self).__init__() + self.params = params + self.encoder = Encoder(params) + self.decoder = Decoder(params) + aux_names = params.get("CCT_aux_decoders".lower(), None) + if aux_names is None: + aux_names = ["DropOut", "FeatureDrop", "FeatureNoise", "VAT"] + aux_decoders = [] + for aux_name in aux_names: + aux_decoders.append(AuxiliaryDecoder(params, aux_name)) + self.aux_decoders = nn.ModuleList(aux_decoders) + + + def forward(self, x): + x_shape = list(x.shape) + if(len(x_shape) == 5): + [N, C, D, H, W] = x_shape + new_shape = [N*D, C, H, W] + x = torch.transpose(x, 1, 2) + x = torch.reshape(x, new_shape) + + f = self.encoder(x) + output = self.decoder(f) + if(len(x_shape) == 5): + new_shape = [N, D] + list(output.shape)[1:] + output = torch.reshape(output, 
new_shape) + output = torch.transpose(output, 1, 2) + + if(self.training): + aux_outputs = [aux_d(f) for aux_d in self.aux_decoders] + if(len(x_shape) == 5): + for i in range(len(aux_outputs)): + aux_outi = torch.reshape(aux_outputs[i], new_shape) + aux_outputs[i] = torch.transpose(aux_outi, 1, 2) + return output, aux_outputs + else: + return output \ No newline at end of file diff --git a/pymic/net/net2d/unet2d_dual_branch.py b/pymic/net/net2d/unet2d_dual_branch.py index 59ec138..3531c89 100644 --- a/pymic/net/net2d/unet2d_dual_branch.py +++ b/pymic/net/net2d/unet2d_dual_branch.py @@ -11,21 +11,15 @@ import torch import torch.nn as nn -import numpy as np -from torch.nn.functional import interpolate from pymic.net.net2d.unet2d import * -class DualBranchUNet2D(UNet2D): +class UNet2D_DualBranch(nn.Module): def __init__(self, params): - params['deep_supervise'] = False - super(DualBranchUNet2D, self).__init__(params) - if(len(self.ft_chns) == 5): - self.up1_aux = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], 0.0, self.bilinear) - self.up2_aux = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], 0.0, self.bilinear) - self.up3_aux = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], 0.0, self.bilinear) - self.up4_aux = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], 0.0, self.bilinear) - - self.out_conv_aux = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size = 1) + super(UNet2D_DualBranch, self).__init__() + self.output_mode = params.get("output_mode", "average") + self.encoder = Encoder(params) + self.decoder1 = Decoder(params) + self.decoder2 = Decoder(params) def forward(self, x): x_shape = list(x.shape) @@ -35,25 +29,22 @@ def forward(self, x): x = torch.transpose(x, 1, 2) x = torch.reshape(x, new_shape) - x0 = self.in_conv(x) - x1 = self.down1(x0) - x2 = self.down2(x1) - x3 = self.down3(x2) - if(len(self.ft_chns) == 5): - x4 = self.down4(x3) - x_d3, x_d3_aux = self.up1(x4, x3), self.up1_aux(x4, x3) - else: - x_d3, x_d3_aux = x3, x3 - - x_d2, x_d2_aux = self.up2(x_d3, x2), self.up2_aux(x_d3_aux, x2) - x_d1, x_d1_aux = self.up3(x_d2, x1), self.up3_aux(x_d2_aux, x1) - x_d0, x_d0_aux = self.up4(x_d1, x0), self.up4_aux(x_d1_aux, x0) - output, output_aux = self.out_conv(x_d0), self.out_conv_aux(x_d0_aux) - + f = self.encoder(x) + output1 = self.decoder1(f) + output2 = self.decoder2(f) if(len(x_shape) == 5): - new_shape = [N, D] + list(output.shape)[1:] - output = torch.reshape(output, new_shape) - output = torch.transpose(output, 1, 2) - output_aux = torch.reshape(output_aux, new_shape) - output_aux = torch.transpose(output_aux, 1, 2) - return output, output_aux \ No newline at end of file + new_shape = [N, D] + list(output1.shape)[1:] + output1 = torch.reshape(output1, new_shape) + output1 = torch.transpose(output1, 1, 2) + output2 = torch.reshape(output2, new_shape) + output2 = torch.transpose(output2, 1, 2) + + if(self.training): + return output1, output2 + else: + if(self.output_mode == "average"): + return (output1 + output2)/2 + elif(self.output_mode == "first"): + return output1 + else: + return output2 diff --git a/pymic/net/net2d/unet2d_scse.py b/pymic/net/net2d/unet2d_scse.py index 95b25b1..6f9d3f7 100644 --- a/pymic/net/net2d/unet2d_scse.py +++ b/pymic/net/net2d/unet2d_scse.py @@ -79,10 +79,10 @@ def __init__(self, params): self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) 
- self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p = 0.0) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p = 0.0) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p = 0.0) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p = 0.0) + self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p = self.dropout[3]) + self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p = self.dropout[2]) + self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p = self.dropout[1]) + self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p = self.dropout[0]) self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size = 3, padding = 1) diff --git a/pymic/net/net3d/unet2d5.py b/pymic/net/net3d/unet2d5.py index 4ed8d7e..9e6a72d 100644 --- a/pymic/net/net3d/unet2d5.py +++ b/pymic/net/net3d/unet2d5.py @@ -149,13 +149,13 @@ def __init__(self, params): self.block3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dims[3], self.dropout[3], True) self.block4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dims[4], self.dropout[4], False) self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], - self.dims[3], dropout_p = 0.0, bilinear = self.bilinear) + self.dims[3], dropout_p = self.dropout[3], bilinear = self.bilinear) self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], - self.dims[2], dropout_p = 0.0, bilinear = self.bilinear) + self.dims[2], dropout_p = self.dropout[2], bilinear = self.bilinear) self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], - self.dims[1], dropout_p = 0.0, bilinear = self.bilinear) + self.dims[1], dropout_p = self.dropout[1], bilinear = self.bilinear) self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], - self.dims[0], dropout_p = 0.0, bilinear = self.bilinear) + self.dims[0], dropout_p = self.dropout[0], bilinear = self.bilinear) self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, kernel_size = (1, 3, 3), padding = (0, 1, 1)) diff --git a/pymic/net/net3d/unet3d.py b/pymic/net/net3d/unet3d.py index a37204e..fdedf4d 100644 --- a/pymic/net/net3d/unet3d.py +++ b/pymic/net/net3d/unet3d.py @@ -96,7 +96,6 @@ def __init__(self, params): self.n_class = self.params['class_num'] self.trilinear = self.params['trilinear'] self.deep_sup = self.params['deep_supervise'] - self.stage = self.params['stage'] assert(len(self.ft_chns) == 5 or len(self.ft_chns) == 4) self.in_conv= ConvBlock(self.in_chns, self.ft_chns[0], self.dropout[0]) @@ -106,13 +105,13 @@ def __init__(self, params): if(len(self.ft_chns) == 5): self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], - dropout_p = 0.0, trilinear=self.trilinear) + dropout_p = self.dropout[3], trilinear=self.trilinear) self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], - dropout_p = 0.0, trilinear=self.trilinear) + dropout_p = self.dropout[2], trilinear=self.trilinear) self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], - dropout_p = 0.0, trilinear=self.trilinear) + dropout_p = self.dropout[1], trilinear=self.trilinear) self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], - dropout_p = 0.0, trilinear=self.trilinear) + dropout_p = self.dropout[0], trilinear=self.trilinear) self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, 
kernel_size = 1) if(self.deep_sup): @@ -134,7 +133,7 @@ def forward(self, x): x_d1 = self.up3(x_d2, x1) x_d0 = self.up4(x_d1, x0) output = self.out_conv(x_d0) - if(self.deep_sup and self.stage == "train"): + if(self.deep_sup): out_shape = list(output.shape)[2:] output1 = self.out_conv1(x_d1) output1 = interpolate(output1, out_shape, mode = 'trilinear') diff --git a/pymic/net/net3d/unet3d_scse.py b/pymic/net/net3d/unet3d_scse.py index 5832830..0f15e25 100644 --- a/pymic/net/net3d/unet3d_scse.py +++ b/pymic/net/net3d/unet3d_scse.py @@ -78,10 +78,10 @@ def __init__(self, params): self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p = 0.0) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p = 0.0) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p = 0.0) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p = 0.0) + self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p = self.dropout[3]) + self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p = self.dropout[2]) + self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p = self.dropout[1]) + self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p = self.dropout[0]) self.out_conv = nn.Conv3d(self.ft_chns[0], self.n_class, kernel_size = 3, padding = 1) diff --git a/pymic/net/net_dict_seg.py b/pymic/net/net_dict_seg.py index 55711d4..0ee554e 100644 --- a/pymic/net/net_dict_seg.py +++ b/pymic/net/net_dict_seg.py @@ -1,8 +1,9 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division from pymic.net.net2d.unet2d import UNet2D -from pymic.net.net2d.unet2d_dual_branch import DualBranchUNet2D +from pymic.net.net2d.unet2d_dual_branch import UNet2D_DualBranch from pymic.net.net2d.unet2d_urpc import UNet2D_URPC +from pymic.net.net2d.unet2d_cct import UNet2D_CCT from pymic.net.net2d.cople_net import COPLENet from pymic.net.net2d.unet2d_attention import AttentionUNet2D from pymic.net.net2d.unet2d_nest import NestedUNet2D @@ -13,8 +14,9 @@ SegNetDict = { 'UNet2D': UNet2D, - 'DualBranchUNet2D': DualBranchUNet2D, + 'UNet2D_DualBranch': UNet2D_DualBranch, 'UNet2D_URPC': UNet2D_URPC, + 'UNet2D_CCT': UNet2D_CCT, 'COPLENet': COPLENet, 'AttentionUNet2D': AttentionUNet2D, 'NestedUNet2D': NestedUNet2D, diff --git a/pymic/net_run/agent_abstract.py b/pymic/net_run/agent_abstract.py index 420e2f9..8ffadc1 100644 --- a/pymic/net_run/agent_abstract.py +++ b/pymic/net_run/agent_abstract.py @@ -3,11 +3,12 @@ import os import random +import logging import torch import numpy as np import torch.optim as optim from abc import ABCMeta, abstractmethod -from pymic.net_run.get_optimizer import get_optimiser +from pymic.net_run.get_optimizer import get_lr_scheduler, get_optimizer def seed_torch(seed=1): random.seed(seed) @@ -42,7 +43,7 @@ def __init__(self, config, stage = 'train'): self.random_seed = config['training'].get('random_seed', 1) if(self.deterministic): seed_torch(self.random_seed) - print("deterministric is true") + logging.info("deterministic is true") def set_datasets(self, train_set, valid_set, test_set): self.train_set = train_set @@ -71,7 +72,9 @@ def get_checkpoint_name(self): ckpt_mode =
self.config['testing']['ckpt_mode'] if(ckpt_mode == 0 or ckpt_mode == 1): ckpt_dir = self.config['training']['ckpt_save_dir'] - ckpt_prefix = ckpt_dir.split('/')[-1] + ckpt_prefix = self.config['training'].get('ckpt_prefix', None) + if(ckpt_prefix is None): + ckpt_prefix = ckpt_dir.split('/')[-1] txt_name = ckpt_dir + '/' + ckpt_prefix txt_name += "_latest.txt" if ckpt_mode == 0 else "_best.txt" with open(txt_name, 'r') as txt_file: @@ -145,19 +148,17 @@ def worker_init_fn(worker_id): batch_size = bn_test, shuffle=False, num_workers= bn_test) def create_optimizer(self, params): + opt_params = self.config['training'] if(self.optimizer is None): - self.optimizer = get_optimiser(self.config['training']['optimizer'], - params, - self.config['training']) + self.optimizer = get_optimizer(opt_params['optimizer'], + params, opt_params) last_iter = -1 if(self.checkpoint is not None): self.optimizer.load_state_dict(self.checkpoint['optimizer_state_dict']) last_iter = self.checkpoint['iteration'] - 1 if(self.scheduler is None): - self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, - self.config['training']['lr_milestones'], - self.config['training']['lr_gamma'], - last_epoch = last_iter) + opt_params["last_iter"] = last_iter + self.scheduler = get_lr_scheduler(self.optimizer, opt_params) def convert_tensor_type(self, input_tensor): if(self.tensor_type == 'float'): diff --git a/pymic/net_run/agent_cls.py b/pymic/net_run/agent_cls.py index 4a80532..8687048 100644 --- a/pymic/net_run/agent_cls.py +++ b/pymic/net_run/agent_cls.py @@ -3,6 +3,7 @@ import copy import csv +import logging import time import torch from torchvision import transforms @@ -71,7 +72,7 @@ def create_network(self): else: self.net.double() param_number = sum(p.numel() for p in self.net.parameters() if p.requires_grad) - print('parameter number:', param_number) + logging.info('parameter number {0:}'.format(param_number)) def get_parameters_to_update(self): params = self.net.get_parameters_to_update() @@ -176,10 +177,10 @@ def write_scalars(self, train_scalars, valid_scalars, glob_it): self.summ_writer.add_scalars('loss', loss_scalar, glob_it) self.summ_writer.add_scalars(metrics, acc_scalar, glob_it) - print("{0:} it {1:}".format(str(datetime.now())[:-7], glob_it)) - print('train loss {0:.4f}, avg {1:} {2:.4f}'.format( + logging.info("{0:} it {1:}".format(str(datetime.now())[:-7], glob_it)) + logging.info('train loss {0:.4f}, avg {1:} {2:.4f}'.format( train_scalars['loss'], metrics, train_scalars[metrics])) - print('valid loss {0:.4f}, avg {1:} {2:.4f}'.format( + logging.info('valid loss {0:.4f}, avg {1:} {2:.4f}'.format( valid_scalars['loss'], metrics, valid_scalars[metrics])) def train_valid(self): @@ -194,7 +195,9 @@ def train_valid(self): ckpt_dir = self.config['training']['ckpt_save_dir'] if(ckpt_dir[-1] == "/"): ckpt_dir = ckpt_dir[:-1] - ckpt_prefx = ckpt_dir.split('/')[-1] + ckpt_prefix = self.config['training'].get('ckpt_prefix', None) + if(ckpt_prefix is None): + ckpt_prefix = ckpt_dir.split('/')[-1] iter_start = self.config['training']['iter_start'] iter_max = self.config['training']['iter_max'] iter_valid = self.config['training']['iter_valid'] @@ -205,7 +208,7 @@ def train_valid(self): self.best_model_wts = None self.checkpoint = None if(iter_start > 0): - checkpoint_file = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefx, iter_start) + checkpoint_file = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, iter_start) self.checkpoint = torch.load(checkpoint_file, map_location = self.device) 
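The checkpoint handling above can be summarized briefly: the new optional `ckpt_prefix` setting falls back to the last component of `ckpt_save_dir`, and checkpoint files keep the `{prefix}_{iteration}.pt` naming. A minimal sketch of that convention, with hypothetical directory and iteration values rather than code from the patch:

```python
# Hypothetical values; illustrates the fallback and naming used in agent_abstract.py / agent_cls.py.
training_cfg = {'ckpt_save_dir': 'model/unet2d_seg', 'ckpt_prefix': None}
ckpt_dir    = training_cfg['ckpt_save_dir']
ckpt_prefix = training_cfg.get('ckpt_prefix', None)
if ckpt_prefix is None:
    ckpt_prefix = ckpt_dir.split('/')[-1]          # falls back to the folder name, "unet2d_seg"
checkpoint_file = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, 5000)
# -> "model/unet2d_seg/unet2d_seg_5000.pt"; "*_latest.txt" / "*_best.txt" store the chosen iteration
```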
assert(self.checkpoint['iteration'] == iter_start) self.net.load_state_dict(self.checkpoint['model_state_dict']) @@ -218,7 +221,7 @@ def train_valid(self): self.trainIter = iter(self.train_loader) - print("{0:} training start".format(str(datetime.now())[:-7])) + logging.info("{0:} training start".format(str(datetime.now())[:-7])) self.summ_writer = SummaryWriter(self.config['training']['ckpt_save_dir']) for it in range(iter_start, iter_max, iter_valid): train_scalars = self.training() @@ -236,9 +239,9 @@ def train_valid(self): 'valid_pred': valid_scalars[metrics], 'model_state_dict': self.net.state_dict(), 'optimizer_state_dict': self.optimizer.state_dict()} - save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefx, glob_it) + save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, glob_it) torch.save(save_dict, save_name) - txt_file = open("{0:}/{1:}_latest.txt".format(ckpt_dir, ckpt_prefx), 'wt') + txt_file = open("{0:}/{1:}_latest.txt".format(ckpt_dir, ckpt_prefix), 'wt') txt_file.write(str(glob_it)) txt_file.close() @@ -247,12 +250,12 @@ def train_valid(self): 'valid_pred': self.max_val_score, 'model_state_dict': self.best_model_wts, 'optimizer_state_dict': self.optimizer.state_dict()} - save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefx, self.max_val_it) + save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.max_val_it) torch.save(save_dict, save_name) - txt_file = open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefx), 'wt') + txt_file = open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefix), 'wt') txt_file.write(str(self.max_val_it)) txt_file.close() - print('The best perfroming iter is {0:}, valid {1:} {2:}'.format(\ + logging.info('The best performing iter is {0:}, valid {1:} {2:}'.format(\ self.max_val_it, metrics, self.max_val_score)) self.summ_writer.close() diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index e4bb97c..5a53f3e 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -10,6 +10,7 @@ import numpy as np import torch.nn as nn import torch.optim as optim +from torch.optim import lr_scheduler import torch.nn.functional as F from datetime import datetime from tensorboardX import SummaryWriter @@ -25,12 +26,15 @@ from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.transform.trans_dict import TransformDict +from pymic.util.post_process import PostProcessDict from pymic.util.image_process import convert_label class SegmentationAgent(NetRunAgent): def __init__(self, config, stage = 'train'): super(SegmentationAgent, self).__init__(config, stage) - self.transform_dict = TransformDict + self.transform_dict = TransformDict + self.postprocess_dict = PostProcessDict + self.postprocessor = None def get_stage_dataset_from_config(self, stage): assert(stage in ['train', 'valid', 'test']) @@ -154,10 +158,13 @@ def get_loss_value(self, data, pred, gt, param = None): loss_value = self.loss_calculator(loss_input_dict) return loss_value + def set_postprocessor(self, postprocessor): + self.postprocessor = postprocessor + def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] - train_loss = 0 + train_loss = 0 train_dice_list = [] self.net.train() for it in range(iter_valid): @@ -192,10 +199,11 @@ def training(self): # forward + backward + optimize outputs = self.net(inputs) loss = self.get_loss_value(data, outputs, labels_prob) - # if (self.config['training']['use'])
loss.backward() self.optimizer.step() - self.scheduler.step() + if(self.scheduler is not None and \ + not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): + self.scheduler.step() train_loss = train_loss + loss.item() # get dice evaluation for each class @@ -251,15 +259,19 @@ def validation(self): valid_cls_dice = np.asarray(valid_dice_list).mean(axis = 0) valid_avg_dice = valid_cls_dice.mean() + if(isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): + self.scheduler.step(valid_avg_dice) + valid_scalers = {'loss': valid_avg_loss, 'avg_dice': valid_avg_dice,\ 'class_dice': valid_cls_dice} return valid_scalers - def write_scalars(self, train_scalars, valid_scalars, glob_it): + def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): loss_scalar ={'train':train_scalars['loss'], 'valid':valid_scalars['loss']} dice_scalar ={'train':train_scalars['avg_dice'], 'valid':valid_scalars['avg_dice']} self.summ_writer.add_scalars('loss', loss_scalar, glob_it) self.summ_writer.add_scalars('dice', dice_scalar, glob_it) + self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it) class_num = self.config['network']['class_num'] for c in range(class_num): cls_dice_scalar = {'train':train_scalars['class_dice'][c], \ @@ -282,24 +294,27 @@ def train_valid(self): self.device = torch.device("cuda:{0:}".format(device_ids[0])) self.net.to(self.device) ckpt_dir = self.config['training']['ckpt_save_dir'] - if(ckpt_dir[-1] == "/"): - ckpt_dir = ckpt_dir[:-1] - ckpt_prefx = ckpt_dir.split('/')[-1] + ckpt_prefix = self.config['training'].get('ckpt_prefix', None) + if(ckpt_prefix is None): + ckpt_prefix = ckpt_dir.split('/')[-1] iter_start = self.config['training']['iter_start'] iter_max = self.config['training']['iter_max'] iter_valid = self.config['training']['iter_valid'] - iter_save = self.config['training']['iter_save'] - if(isinstance(iter_save, (tuple, list))): + iter_save = self.config['training'].get('iter_save', None) + early_stop_it = self.config['training'].get('early_stop_patience', None) + if(iter_save is None): + iter_save_list = [iter_max] + elif(isinstance(iter_save, (tuple, list))): iter_save_list = iter_save else: - iter_save_list = range(iter_start, iter_max +1, iter_save) + iter_save_list = range(0, iter_max + 1, iter_save) self.max_val_dice = 0.0 self.max_val_it = 0 self.best_model_wts = None self.checkpoint = None if(iter_start > 0): - checkpoint_file = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefx, iter_start) + checkpoint_file = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, iter_start) self.checkpoint = torch.load(checkpoint_file, map_location = self.device) # assert(self.checkpoint['iteration'] == iter_start) if(len(device_ids) > 1): @@ -320,15 +335,18 @@ def train_valid(self): self.summ_writer = SummaryWriter(self.config['training']['ckpt_save_dir']) self.glob_it = iter_start for it in range(iter_start, iter_max, iter_valid): + lr_value = self.optimizer.param_groups[0]['lr'] t0 = time.time() train_scalars = self.training() t1 = time.time() + valid_scalars = self.validation() t2 = time.time() self.glob_it = it + iter_valid - logging.info("{0:} it {1:}".format(str(datetime.now())[:-7], self.glob_it)) + logging.info("\n{0:} it {1:}".format(str(datetime.now())[:-7], self.glob_it)) + logging.info('learning rate {0:}'.format(lr_value)) logging.info("training/validation time: {0:.2f}s/{1:.2f}s".format(t1-t0, t2-t1)) - self.write_scalars(train_scalars, valid_scalars, self.glob_it) + self.write_scalars(train_scalars, valid_scalars, lr_value, self.glob_it) 
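The scheduler handling introduced above follows a common convention: iteration-based schedulers such as MultiStepLR are stepped once per training iteration, while ReduceLROnPlateau is stepped with the validation Dice after each validation round. A standalone sketch of that convention, using a toy model and a placeholder metric rather than code from the patch:

```python
import torch
from torch import nn, optim
from torch.optim import lr_scheduler

net = nn.Linear(4, 2)                       # stand-in for a segmentation network
opt = optim.SGD(net.parameters(), lr=0.01)
sched = lr_scheduler.ReduceLROnPlateau(opt, mode="max", factor=0.5, patience=10)

for it in range(100):                       # training iterations
    opt.zero_grad()
    loss = net(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    opt.step()
    if not isinstance(sched, lr_scheduler.ReduceLROnPlateau):
        sched.step()                        # e.g. MultiStepLR: step every iteration
    if (it + 1) % 20 == 0:                  # pretend this is a validation round
        valid_avg_dice = 0.5                # placeholder validation metric
        if isinstance(sched, lr_scheduler.ReduceLROnPlateau):
            sched.step(valid_avg_dice)      # plateau scheduler consumes the metric
```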
if(valid_scalars['avg_dice'] > self.max_val_dice): self.max_val_dice = valid_scalars['avg_dice'] self.max_val_it = self.glob_it @@ -337,25 +355,30 @@ def train_valid(self): else: self.best_model_wts = copy.deepcopy(self.net.state_dict()) - if (self.glob_it in iter_save_list): + stop_now = True if(early_stop_it is not None and \ + self.glob_it - self.max_val_it > early_stop_it) else False + if ((self.glob_it in iter_save_list) or stop_now): save_dict = {'iteration': self.glob_it, 'valid_pred': valid_scalars['avg_dice'], 'model_state_dict': self.net.module.state_dict() \ if len(device_ids) > 1 else self.net.state_dict(), 'optimizer_state_dict': self.optimizer.state_dict()} - save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefx, self.glob_it) + save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.glob_it) torch.save(save_dict, save_name) - txt_file = open("{0:}/{1:}_latest.txt".format(ckpt_dir, ckpt_prefx), 'wt') + txt_file = open("{0:}/{1:}_latest.txt".format(ckpt_dir, ckpt_prefix), 'wt') txt_file.write(str(self.glob_it)) txt_file.close() + if(stop_now): + logging.info("The training is early stopped") + break # save the best performing checkpoint save_dict = {'iteration': self.max_val_it, 'valid_pred': self.max_val_dice, 'model_state_dict': self.best_model_wts, 'optimizer_state_dict': self.optimizer.state_dict()} - save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefx, self.max_val_it) + save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.max_val_it) torch.save(save_dict, save_name) - txt_file = open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefx), 'wt') + txt_file = open("{0:}/{1:}_best.txt".format(ckpt_dir, ckpt_prefix), 'wt') txt_file.write(str(self.max_val_it)) txt_file.close() logging.info('The best performing iter is {0:}, valid dice {1:}'.format(\ @@ -394,6 +417,9 @@ def test_time_dropout(m): infer_cfg = self.config['testing'] infer_cfg['class_num'] = self.config['network']['class_num'] self.inferer = Inferer(infer_cfg) + postpro_name = self.config['testing'].get('post_process', None) + if(self.postprocessor is None and postpro_name is not None): + self.postprocessor = PostProcessDict[postpro_name](self.config['testing']) infer_time_list = [] with torch.no_grad(): for data in self.test_loader: @@ -428,7 +454,7 @@ def test_time_dropout(m): self.save_ouputs(data) infer_time_list = np.asarray(infer_time_list) time_avg, time_std = infer_time_list.mean(), infer_time_list.std() - print("testing time {0:} +/- {1:}".format(time_avg, time_std)) + logging.info("testing time {0:} +/- {1:}".format(time_avg, time_std)) def infer_with_multiple_checkpoints(self): """ @@ -482,7 +508,7 @@ def infer_with_multiple_checkpoints(self): self.save_ouputs(data) infer_time_list = np.asarray(infer_time_list) time_avg, time_std = infer_time_list.mean(), infer_time_list.std() - print("testing time {0:} +/- {1:}".format(time_avg, time_std)) + logging.info("testing time {0:} +/- {1:}".format(time_avg, time_std)) def save_ouputs(self, data): output_dir = self.config['testing']['output_dir'] @@ -493,7 +519,7 @@ def save_ouputs(self, data): filename_replace_source = self.config['testing'].get('filename_replace_source', None) filename_replace_target = self.config['testing'].get('filename_replace_target', None) if(not os.path.exists(output_dir)): - os.mkdir(output_dir) + os.makedirs(output_dir, exist_ok=True) names, pred = data['names'], data['predict'] if(isinstance(pred, (list, tuple))): @@ -502,6 +528,9 @@ def save_ouputs(self, data): output = np.asarray(np.argmax(prob, 
axis = 1), np.uint8) if((label_source is not None) and (label_target is not None)): output = convert_label(output, label_source, label_target) + if(self.postprocessor is not None): + for i in range(len(names)): + output[i] = self.postprocessor(output[i]) # save the output and (optionally) probability predictions root_dir = self.config['dataset']['root_dir'] for i in range(len(names)): diff --git a/pymic/net_run/get_optimizer.py b/pymic/net_run/get_optimizer.py index 7170b6e..c4504de 100644 --- a/pymic/net_run/get_optimizer.py +++ b/pymic/net_run/get_optimizer.py @@ -2,33 +2,56 @@ from __future__ import print_function, division import torch -import torch.optim as optim +from torch import optim +from torch.optim import lr_scheduler +from pymic.util.general import keyword_match -def get_optimiser(name, net_params, optim_params): +def get_optimizer(name, net_params, optim_params): lr = optim_params['learning_rate'] momentum = optim_params['momentum'] weight_decay = optim_params['weight_decay'] - if(name == "SGD"): + if(keyword_match(name, "SGD")): return optim.SGD(net_params, lr, momentum = momentum, weight_decay = weight_decay) - elif(name == "Adam"): + elif(keyword_match(name, "Adam")): return optim.Adam(net_params, lr, weight_decay = weight_decay) - elif(name == "SparseAdam"): + elif(keyword_match(name, "SparseAdam")): return optim.SparseAdam(net_params, lr) - elif(name == "Adadelta"): + elif(keyword_match(name, "Adadelta")): return optim.Adadelta(net_params, lr, weight_decay = weight_decay) - elif(name == "Adagrad"): + elif(keyword_match(name, "Adagrad")): return optim.Adagrad(net_params, lr, weight_decay = weight_decay) - elif(name == "Adamax"): + elif(keyword_match(name, "Adamax")): return optim.Adamax(net_params, lr, weight_decay = weight_decay) - elif(name == "ASGD"): + elif(keyword_match(name, "ASGD")): return optim.ASGD(net_params, lr, weight_decay = weight_decay) - elif(name == "LBFGS"): + elif(keyword_match(name, "LBFGS")): return optim.LBFGS(net_params, lr) - elif(name == "RMSprop"): + elif(keyword_match(name, "RMSprop")): return optim.RMSprop(net_params, lr, momentum = momentum, weight_decay = weight_decay) - elif(name == "Rprop"): + elif(keyword_match(name, "Rprop")): return optim.Rprop(net_params, lr) else: raise ValueError("unsupported optimizer {0:}".format(name)) + + +def get_lr_scheduler(optimizer, sched_params): + name = sched_params["lr_scheduler"] + if(name is None): + return None + lr_gamma = sched_params["lr_gamma"] + if(keyword_match(name, "ReduceLROnPlateau")): + patience_it = sched_params["ReduceLROnPlateau_patience".lower()] + val_it = sched_params["iter_valid"] + patience = patience_it / val_it + scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, + mode = "max", factor=lr_gamma, patience = patience) + elif(keyword_match(name, "MultiStepLR")): + lr_milestones = sched_params["lr_milestones"] + last_iter = sched_params["last_iter"] + scheduler = lr_scheduler.MultiStepLR(optimizer, + lr_milestones, lr_gamma, last_iter) + else: + raise ValueError("unsupported lr scheduler {0:}".format(name)) + return scheduler \ No newline at end of file diff --git a/pymic/net_run/infer_func.py b/pymic/net_run/infer_func.py index e603725..78184fe 100644 --- a/pymic/net_run/infer_func.py +++ b/pymic/net_run/infer_func.py @@ -131,24 +131,26 @@ def run(self, model, image): tta_mode = self.config.get('tta_mode', 0) if(tta_mode == 0): outputs = self.__infer(image) - elif(tta_mode == 1): # test time augmentation with flip in 2D + elif(tta_mode == 1): + # test time augmentation with flip 
in 2D + # you may define your own method for test time augmentation outputs1 = self.__infer(image) outputs2 = self.__infer(torch.flip(image, [-2])) - outputs3 = self.__infer(torch.flip(image, [-3])) - outputs4 = self.__infer(torch.flip(image, [-2, -3])) + outputs3 = self.__infer(torch.flip(image, [-1])) + outputs4 = self.__infer(torch.flip(image, [-2, -1])) if(isinstance(outputs1, (tuple, list))): outputs = [] - for i in range(len(outputs)): + for i in range(len(outputs1)): temp_out1 = outputs1[i] temp_out2 = torch.flip(outputs2[i], [-2]) - temp_out3 = torch.flip(outputs3[i], [-3]) - temp_out4 = torch.flip(outputs4[i], [-2, -3]) + temp_out3 = torch.flip(outputs3[i], [-1]) + temp_out4 = torch.flip(outputs4[i], [-2, -1]) temp_mean = (temp_out1 + temp_out2 + temp_out3 + temp_out4) / 4 outputs.append(temp_mean) else: outputs2 = torch.flip(outputs2, [-2]) - outputs3 = torch.flip(outputs3, [-3]) - outputs4 = torch.flip(outputs4, [-2, -3]) + outputs3 = torch.flip(outputs3, [-1]) + outputs4 = torch.flip(outputs4, [-2, -1]) outputs = (outputs1 + outputs2 + outputs3 + outputs4) / 4 else: raise ValueError("Undefined tta_mode {0:}".format(tta_mode)) diff --git a/pymic/net_run/net_run.py b/pymic/net_run/net_run.py index 4ec1ce7..4c953ad 100644 --- a/pymic/net_run/net_run.py +++ b/pymic/net_run/net_run.py @@ -10,7 +10,7 @@ def main(): if(len(sys.argv) < 3): print('Number of arguments should be 3. e.g.') - print(' pymic_net_run train config.cfg') + print(' pymic_run train config.cfg') exit() stage = str(sys.argv[1]) cfg_file = str(sys.argv[2]) @@ -18,8 +18,8 @@ def main(): config = synchronize_config(config) log_dir = config['training']['ckpt_save_dir'] if(not os.path.exists(log_dir)): - os.mkdir(log_dir) - logging.basicConfig(filename=log_dir+"/log.txt", level=logging.INFO, + os.makedirs(log_dir, exist_ok=True) + logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, format='%(message)s') logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) logging_config(config) diff --git a/pymic/net_run_nll/__init__.py b/pymic/net_run_nll/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pymic/net_run_noise/cl.py b/pymic/net_run_nll/nll_cl.py similarity index 85% rename from pymic/net_run_noise/cl.py rename to pymic/net_run_nll/nll_cl.py index de31e3e..8173471 100644 --- a/pymic/net_run_noise/cl.py +++ b/pymic/net_run_nll/nll_cl.py @@ -14,6 +14,7 @@ import sys import torch import numpy as np +import pandas as pd import torch.nn as nn import torchvision.transforms as transforms from PIL import Image @@ -45,9 +46,9 @@ def get_confident_map(gt, pred, CL_type = 'both'): noise = cleanlab.pruning.get_noise_indices(gt, prob, prune_method=CL_type, n_jobs=1) return noise -class SegmentationAgentwithCL(SegmentationAgent): +class NLLConfidentLearn(SegmentationAgent): def __init__(self, config, stage = 'test'): - super(SegmentationAgentwithCL, self).__init__(config, stage) + super(NLLConfidentLearn, self).__init__(config, stage) def infer_with_cl(self): device_ids = self.config['testing']['gpus'] @@ -93,16 +94,6 @@ def test_time_dropout(m): filename_list.append(names) images = images.to(device) - # for debug - # for i in range(images.shape[0]): - # image_i = images[i][0] - # label_i = images[i][0] - # image_name = "temp/{0:}_image.nii.gz".format(names[0]) - # label_name = "temp/{0:}_label.nii.gz".format(names[0]) - # save_nd_array_as_image(image_i, image_name, reference_name = None) - # save_nd_array_as_image(label_i, label_name, reference_name = None) - # continue 
- pred = self.inferer.run(self.net, images) # convert tensor to numpy if(isinstance(pred, (tuple, list))): @@ -142,15 +133,10 @@ def test_time_dropout(m): dst_path = os.path.join(save_dir, filename) conf_map.save(dst_path) - def run(self): - self.create_dataset() - self.create_network() - self.infer_with_cl() - -def main(): +def get_confidence_map(): if(len(sys.argv) < 2): print('Number of arguments should be 2. e.g.') - print(' python cl.py config.cfg') + print(' python nll_cl.py config.cfg') exit() cfg_file = str(sys.argv[1]) config = parse_config(cfg_file) @@ -172,17 +158,35 @@ def main(): transform_list.append(one_transform) data_transform = transforms.Compose(transform_list) print('transform list', transform_list) - csv_file = config['dataset']['train_csv'] + csv_file = config['dataset']['train_csv'] + modal_num = config['dataset'].get('modal_num', 1) dataset = NiftyDataset(root_dir = config['dataset']['root_dir'], csv_file = csv_file, - modal_num = config['dataset']['modal_num'], + modal_num = modal_num, with_label= True, transform = data_transform ) - agent = SegmentationAgentwithCL(config, 'test') + agent = NLLConfidentLearn(config, 'test') agent.set_datasets(None, None, dataset) agent.transform_list = transform_list - agent.run() + agent.create_dataset() + agent.create_network() + agent.infer_with_cl() + + # create training csv for confidence learning + df_train = pd.read_csv(csv_file) + pixel_weight = [] + weight_dir = config['testing']['output_dir'] + "_conf" + for i in range(len(df_train["label"])): + lab_name = df_train["label"][i].split('/')[-1] + weight_name = "../" + weight_dir + '/' + lab_name + pixel_weight.append(weight_name) + train_cl_dict = {"image": df_train["image"], + "pixel_weight": pixel_weight, + "label": df_train["label"]} + train_cl_csv = csv_file.replace(".csv", "_cl.csv") + df_cl = pd.DataFrame.from_dict(train_cl_dict) + df_cl.to_csv(train_cl_csv, index = False) if __name__ == "__main__": - main() \ No newline at end of file + get_confidence_map() \ No newline at end of file diff --git a/pymic/net_run_nll/nll_clslsr.py b/pymic/net_run_nll/nll_clslsr.py new file mode 100644 index 0000000..9ee7182 --- /dev/null +++ b/pymic/net_run_nll/nll_clslsr.py @@ -0,0 +1,192 @@ +# -*- coding: utf-8 -*- +""" +Calculating the confidence map of labels of training samples, +which is used in the method of SLSR. + Minqing Zhang et al., Characterizing Label Errors: Confident Learning + for Noisy-Labeled Image Segmentation, MICCAI 2020.
+ https://link.springer.com/chapter/10.1007/978-3-030-59710-8_70 +""" + +from __future__ import print_function, division +import cleanlab +import logging +import os +import scipy +import sys +import torch +import numpy as np +import pandas as pd +import torch.nn as nn +import torchvision.transforms as transforms +from PIL import Image +from pymic.io.nifty_dataset import NiftyDataset +from pymic.transform.trans_dict import TransformDict +from pymic.util.parse_config import * +from pymic.net_run.agent_seg import SegmentationAgent +from pymic.net_run.infer_func import Inferer + +def get_confident_map(gt, pred, CL_type = 'both'): + """ + gt: ground truth label (one-hot) with shape of NXC + pred: digit prediction of network with shape of NXC + """ + prob = scipy.special.softmax(pred, axis = 1) + if CL_type in ['both', 'Qij']: + noise = cleanlab.pruning.get_noise_indices(gt, prob, prune_method='both', n_jobs=1) + elif CL_type == 'Cij': + noise = cleanlab.pruning.get_noise_indices(gt, pred, prune_method='both', n_jobs=1) + elif CL_type == 'intersection': + noise_qij = cleanlab.pruning.get_noise_indices(gt, prob, prune_method='both', n_jobs=1) + noise_cij = cleanlab.pruning.get_noise_indices(gt, pred, prune_method='both', n_jobs=1) + noise = noise_qij & noise_cij + elif CL_type == 'union': + noise_qij = cleanlab.pruning.get_noise_indices(gt, prob, prune_method='both', n_jobs=1) + noise_cij = cleanlab.pruning.get_noise_indices(gt, pred, prune_method='both', n_jobs=1) + noise = noise_qij | noise_cij + elif CL_type in ['prune_by_class', 'prune_by_noise_rate']: + noise = cleanlab.pruning.get_noise_indices(gt, prob, prune_method=CL_type, n_jobs=1) + return noise + +class NLLCLSLSR(SegmentationAgent): + def __init__(self, config, stage = 'test'): + super(NLLCLSLSR, self).__init__(config, stage) + + def infer_with_cl(self): + device_ids = self.config['testing']['gpus'] + device = torch.device("cuda:{0:}".format(device_ids[0])) + self.net.to(device) + + if(self.config['testing'].get('evaluation_mode', True)): + self.net.eval() + if(self.config['testing'].get('test_time_dropout', False)): + def test_time_dropout(m): + if(type(m) == nn.Dropout): + logging.info('dropout layer') + m.train() + self.net.apply(test_time_dropout) + + ckpt_mode = self.config['testing']['ckpt_mode'] + ckpt_name = self.get_checkpoint_name() + if(ckpt_mode == 3): + assert(isinstance(ckpt_name, (tuple, list))) + self.infer_with_multiple_checkpoints() + return + else: + if(isinstance(ckpt_name, (tuple, list))): + raise ValueError("ckpt_mode should be 3 if ckpt_name is a list") + + # load network parameters and set the network as evaluation mode + checkpoint = torch.load(ckpt_name, map_location = device) + self.net.load_state_dict(checkpoint['model_state_dict']) + + if(self.inferer is None): + infer_cfg = self.config['testing'] + class_num = self.config['network']['class_num'] + infer_cfg['class_num'] = class_num + self.inferer = Inferer(infer_cfg) + pred_list = [] + gt_list = [] + filename_list = [] + with torch.no_grad(): + for data in self.test_loader: + images = self.convert_tensor_type(data['image']) + labels = self.convert_tensor_type(data['label_prob']) + names = data['names'] + filename_list.append(names) + images = images.to(device) + + pred = self.inferer.run(self.net, images) + # convert tensor to numpy + if(isinstance(pred, (tuple, list))): + pred = [item.cpu().numpy() for item in pred] + else: + pred = pred.cpu().numpy() + data['predict'] = pred + # inverse transform + for transform in self.transform_list[::-1]: + if 
(transform.inverse): + data = transform.inverse_transform_for_prediction(data) + + pred = data['predict'] + # conver prediction from N, C, H, W to (N*H*W)*C + print(names, pred.shape, labels.shape) + pred_2d = np.swapaxes(pred, 1, 2) + pred_2d = np.swapaxes(pred_2d, 2, 3) + pred_2d = pred_2d.reshape(-1, class_num) + lab = labels.cpu().numpy() + lab_2d = np.swapaxes(lab, 1, 2) + lab_2d = np.swapaxes(lab_2d, 2, 3) + lab_2d = lab_2d.reshape(-1, class_num) + pred_list.append(pred_2d) + gt_list.append(lab_2d) + + pred_cat = np.concatenate(pred_list) + gt_cat = np.concatenate(gt_list) + gt = np.argmax(gt_cat, axis = 1) + gt = gt.reshape(-1).astype(np.uint8) + print(gt.shape, pred_cat.shape) + conf = get_confident_map(gt, pred_cat) + conf = conf.reshape(-1, 256, 256).astype(np.uint8) * 255 + save_dir = self.config['dataset']['root_dir'] + "/slsr_conf" + for idx in range(len(filename_list)): + filename = filename_list[idx][0].split('/')[-1] + conf_map = Image.fromarray(conf[idx]) + dst_path = os.path.join(save_dir, filename) + conf_map.save(dst_path) + +def get_confidence_map(): + if(len(sys.argv) < 2): + print('Number of arguments should be 3. e.g.') + print(' python nll_cl.py config.cfg') + exit() + cfg_file = str(sys.argv[1]) + config = parse_config(cfg_file) + config = synchronize_config(config) + + # set dataset + transform_names = config['dataset']['valid_transform'] + transform_list = [] + transform_dict = TransformDict + if(transform_names is None or len(transform_names) == 0): + data_transform = None + else: + transform_param = config['dataset'] + transform_param['task'] = 'segmentation' + for name in transform_names: + if(name not in transform_dict): + raise(ValueError("Undefined transform {0:}".format(name))) + one_transform = transform_dict[name](transform_param) + transform_list.append(one_transform) + data_transform = transforms.Compose(transform_list) + print('transform list', transform_list) + csv_file = config['dataset']['train_csv'] + modal_num = config['dataset'].get('modal_num', 1) + dataset = NiftyDataset(root_dir = config['dataset']['root_dir'], + csv_file = csv_file, + modal_num = modal_num, + with_label= True, + transform = data_transform ) + + agent = NLLCLSLSR(config, 'test') + agent.set_datasets(None, None, dataset) + agent.transform_list = transform_list + agent.create_dataset() + agent.create_network() + agent.infer_with_cl() + + # create training csv for confidence learning + df_train = pd.read_csv(csv_file) + pixel_weight = [] + for i in range(len(df_train["label"])): + lab_name = df_train["label"][i].split('/')[-1] + weight_name = "slsr_conf/" + lab_name + pixel_weight.append(weight_name) + train_cl_dict = {"image": df_train["image"], + "pixel_weight": pixel_weight, + "label": df_train["label"]} + train_cl_csv = csv_file.replace(".csv", "_clslsr.csv") + df_cl = pd.DataFrame.from_dict(train_cl_dict) + df_cl.to_csv(train_cl_csv, index = False) + +if __name__ == "__main__": + get_confidence_map() \ No newline at end of file diff --git a/pymic/net_run_noise/co_teaching.py b/pymic/net_run_nll/nll_co_teaching.py similarity index 64% rename from pymic/net_run_noise/co_teaching.py rename to pymic/net_run_nll/nll_co_teaching.py index 228e3bd..bcaec4e 100644 --- a/pymic/net_run_noise/co_teaching.py +++ b/pymic/net_run_nll/nll_co_teaching.py @@ -11,76 +11,67 @@ """ from __future__ import print_function, division import logging +import os +import sys import numpy as np import torch import torch.nn as nn import torch.optim as optim +from torch.optim import lr_scheduler from 
pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.util import reshape_tensor_to_2D -from pymic.util.ramps import sigmoid_rampup -from pymic.net_run.get_optimizer import get_optimiser from pymic.net_run.agent_seg import SegmentationAgent from pymic.net.net_dict_seg import SegNetDict - -import logging -import os -import sys from pymic.util.parse_config import * +from pymic.util.ramps import get_rampup_ratio + +class BiNet(nn.Module): + def __init__(self, params): + super(BiNet, self).__init__() + net_name = params['net_type'] + self.net1 = SegNetDict[net_name](params) + self.net2 = SegNetDict[net_name](params) -class CoTeachingAgent(SegmentationAgent): + def forward(self, x): + out1 = self.net1(x) + out2 = self.net2(x) + + if(self.training): + return out1, out2 + else: + return (out1 + out2) / 3 + +class NLLCoTeaching(SegmentationAgent): """ - Using cross pseudo supervision according to the following paper: - Xiaokang Chen, Yuhui Yuan, Gang Zeng, Jingdong Wang, - Semi-Supervised Semantic Segmentation with Cross Pseudo Supervision, - CVPR 2021, pp. 2613-2022. - https://arxiv.org/abs/2106.01226 + Co-teaching: Robust Training of Deep Neural Networks with Extremely + Noisy Labels + https://arxiv.org/abs/1804.06872 """ def __init__(self, config, stage = 'train'): - super(CoTeachingAgent, self).__init__(config, stage) - self.net2 = None - self.optimizer2 = None - self.scheduler2 = None + super(NLLCoTeaching, self).__init__(config, stage) loss_type = config['training']["loss_type"] if(loss_type != "CrossEntropyLoss"): logging.warn("only CrossEntropyLoss supported for" + " coteaching, the specified loss {0:} is ingored".format(loss_type)) def create_network(self): - super(CoTeachingAgent, self).create_network() - if(self.net2 is None): - net_name = self.config['network']['net_type'] - if(net_name not in SegNetDict): - raise ValueError("Undefined network {0:}".format(net_name)) - self.net2 = SegNetDict[net_name](self.config['network']) + if(self.net is None): + self.net = BiNet(self.config['network']) if(self.tensor_type == 'float'): - self.net2.float() + self.net.float() else: - self.net2.double() - - def train_valid(self): - # create optimizor for the second network - if(self.optimizer2 is None): - self.optimizer2 = get_optimiser(self.config['training']['optimizer'], - self.net2.parameters(), - self.config['training']) - last_iter = -1 - # if(self.checkpoint is not None): - # self.optimizer2.load_state_dict(self.checkpoint['optimizer_state_dict']) - # last_iter = self.checkpoint['iteration'] - 1 - if(self.scheduler2 is None): - self.scheduler2 = optim.lr_scheduler.MultiStepLR(self.optimizer2, - self.config['training']['lr_milestones'], - self.config['training']['lr_gamma'], - last_epoch = last_iter) - super(CoTeachingAgent, self).train_valid() + self.net.double() def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] - select_ratio = self.config['training']['co_teaching_select_ratio'] - rampup_length = self.config['training']['co_teaching_rampup_length'] + nll_cfg = self.config['noisy_label_learning'] + select_ratio = nll_cfg['co_teaching_select_ratio'] + iter_max = self.config['training']['iter_max'] + rampup_start = nll_cfg.get('rampup_start', 0) + rampup_end = nll_cfg.get('rampup_end', iter_max) train_loss_no_select1 = 0 train_loss_no_select2 = 0 @@ -88,8 +79,6 @@ def training(self): train_loss2 = 0 
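For reference, a minimal standalone sketch of the small-loss selection with cross updating that co-teaching uses, as described in the paper cited above: each network ranks per-sample losses, keeps the smallest fraction, and its peer is updated on that selection. The tensors and the remember ratio are toy values, not PyMIC code.

```python
# Standalone sketch of co-teaching's small-loss selection with cross updating (toy tensors).
import torch

loss1 = torch.rand(16)                      # per-sample losses of branch 1
loss2 = torch.rand(16)                      # per-sample losses of branch 2
remember_ratio = 0.8                        # = 1 - forget_ratio, shrinks during ramp-up
num_remember = int(remember_ratio * len(loss1))

idx1 = torch.argsort(loss1)[:num_remember]  # samples branch 1 considers clean
idx2 = torch.argsort(loss2)[:num_remember]  # samples branch 2 considers clean

# cross update: each branch learns from the samples selected by its peer
loss_for_branch1 = loss1[idx2].mean()
loss_for_branch2 = loss2[idx1].mean()
total_loss = loss_for_branch1 + loss_for_branch2
```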
train_dice_list = [] self.net.train() - self.net2.train() - self.net2.to(self.device) for it in range(iter_valid): try: data = next(self.trainIter) @@ -104,11 +93,9 @@ def training(self): # zero the parameter gradients self.optimizer.zero_grad() - self.optimizer2.zero_grad() # forward + backward + optimize - outputs1 = self.net(inputs) - outputs2 = self.net2(inputs) + outputs1, outputs2 = self.net(inputs) prob1 = nn.Softmax(dim = 1)(outputs1) prob2 = nn.Softmax(dim = 1)(outputs2) @@ -124,8 +111,9 @@ def training(self): loss2 = torch.sum(loss2, dim = 1) # shape is [N] ind_2_sorted = torch.argsort(loss2) - forget_ratio = (1 - select_ratio) * self.glob_it / rampup_length - remb_ratio = max(select_ratio, 1 - forget_ratio) + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") + forget_ratio = (1 - select_ratio) * rampup_ratio + remb_ratio = 1 - forget_ratio num_remb = int(remb_ratio * len(loss1)) ind_1_update = ind_1_sorted[:num_remb] @@ -136,22 +124,17 @@ def training(self): loss = loss1_select.mean() + loss2_select.mean() - # if (self.config['training']['use']) loss.backward() self.optimizer.step() - self.scheduler.step() - self.optimizer2.step() - self.scheduler2.step() + if(self.scheduler is not None and \ + not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): + self.scheduler.step() train_loss_no_select1 = train_loss_no_select1 + loss1.mean().item() train_loss_no_select2 = train_loss_no_select2 + loss2.mean().item() train_loss1 = train_loss1 + loss1_select.mean().item() train_loss2 = train_loss2 + loss2_select.mean().item() - # get dice evaluation for each class in annotated images - # if(isinstance(outputs1, tuple) or isinstance(outputs1, list)): - # outputs1 = outputs1[0] - outputs1_argmax = torch.argmax(outputs1, dim = 1, keepdim = True) soft_out1 = get_soft_label(outputs1_argmax, class_num, self.tensor_type) soft_out1, labels_prob = reshape_prediction_and_ground_truth(soft_out1, labels_prob) @@ -171,7 +154,7 @@ def training(self): 'select_ratio':remb_ratio, 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice} return train_scalers - def write_scalars(self, train_scalars, valid_scalars, glob_it): + def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): loss_scalar ={'train':train_scalars['loss'], 'valid':valid_scalars['loss']} loss_no_select_scalar = {'net1':train_scalars['loss_no_select1'], @@ -181,6 +164,7 @@ def write_scalars(self, train_scalars, valid_scalars, glob_it): self.summ_writer.add_scalars('loss', loss_scalar, glob_it) self.summ_writer.add_scalars('loss_no_select', loss_no_select_scalar, glob_it) self.summ_writer.add_scalars('select_ratio', {'select_ratio':train_scalars['select_ratio']}, glob_it) + self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it) self.summ_writer.add_scalars('dice', dice_scalar, glob_it) class_num = self.config['network']['class_num'] for c in range(class_num): @@ -194,22 +178,3 @@ def write_scalars(self, train_scalars, valid_scalars, glob_it): logging.info('valid loss {0:.4f}, avg dice {1:.4f} '.format( valid_scalars['loss'], valid_scalars['avg_dice']) + "[" + \ ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") - -if __name__ == "__main__": - if(len(sys.argv) < 3): - print('Number of arguments should be 3. 
e.g.') - print(' pymic_ssl train config.cfg') - exit() - stage = str(sys.argv[1]) - cfg_file = str(sys.argv[2]) - config = parse_config(cfg_file) - config = synchronize_config(config) - log_dir = config['training']['ckpt_save_dir'] - if(not os.path.exists(log_dir)): - os.mkdir(log_dir) - logging.basicConfig(filename=log_dir+"/log.txt", level=logging.INFO, - format='%(message)s') - logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) - logging_config(config) - agent = CoTeachingAgent(config, stage) - agent.run() \ No newline at end of file diff --git a/pymic/net_run_nll/nll_dast.py b/pymic/net_run_nll/nll_dast.py new file mode 100644 index 0000000..19a59a2 --- /dev/null +++ b/pymic/net_run_nll/nll_dast.py @@ -0,0 +1,261 @@ +# -*- coding: utf-8 -*- +""" +Implementation of DAST for noise robust learning according to the following paper. + Shuojue Yang, Guotai Wang, Hui Sun, Xiangde Luo, Peng Sun, Kang Li, Qijun Wang, + Shaoting Zhang: Learning COVID-19 Pneumonia Lesion Segmentation from Imperfect + Annotations via Divergence-Aware Selective Training. + JBHI 2022. https://ieeexplore.ieee.org/document/9770406 +""" + +from __future__ import print_function, division +import random +import torch +import numpy as np +import torch.nn as nn +import torchvision.transforms as transforms +from torch.optim import lr_scheduler +from pymic.io.nifty_dataset import NiftyDataset +from pymic.loss.seg.util import get_soft_label +from pymic.loss.seg.util import reshape_prediction_and_ground_truth +from pymic.loss.seg.util import get_classwise_dice +from pymic.net_run.agent_seg import SegmentationAgent +from pymic.util.parse_config import * +from pymic.util.ramps import get_rampup_ratio + +class Rank(object): + """ + Dynamically rank the current training sample with specific metrics + """ + def __init__(self, quene_length = 100): + self.vals = [] + self.quene_length = quene_length + + def add_val(self, val): + """ + Update the quene and calculate the order of the input value. 
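To make the returned rank concrete, here is a small numpy illustration of the sliding-window ranking this method implements below; the values are made up.

```python
# Toy illustration of the sliding-window rank computed by Rank.add_val (not PyMIC code).
import numpy as np

window, length = [], 5

def add_val(val):
    if len(window) < length:
        window.append(val)
        return -1                                    # rank undefined until the window is full
    window.pop(0)
    window.append(val)
    order = np.argsort(window)                       # indices that would sort the window
    return int(np.where(order == length - 1)[0][0])  # position of the newest value

for v in [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 1.0]:
    print(add_val(v))   # -1 for the first five calls, then 0 (smallest), then 4 (largest)
```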
+ + Return + --------- + rank: rank of the input value with a range of (0, self.quenen_length) + """ + if len(self.vals) < self.quene_length: + self.vals.append(val) + rank = -1 + else: + self.vals.pop(0) + self.vals.append(val) + assert len(self.vals) == self.quene_length + idxes = np.argsort(self.vals) + rank = np.where(idxes == self.quene_length-1)[0][0] + return rank + +class ConsistLoss(nn.Module): + def __init__(self): + super(ConsistLoss, self).__init__() + + def kl_div_map(self, input, label): + kl_map = torch.sum(label * (torch.log(label + 1e-16) - torch.log(input + 1e-16)), dim = 1) + return kl_map + + def kl_loss(self,input, target, size_average=True): + kl_div = self.kl_div_map(input, target) + if size_average: + return torch.mean(kl_div) + else: + return kl_div + + def forward(self, input1, input2, size_average = True): + kl1 = self.kl_loss(input1, input2.detach(), size_average=size_average) + kl2 = self.kl_loss(input2, input1.detach(), size_average=size_average) + return (kl1 + kl2) / 2 + +def get_ce(prob, soft_y, size_avg = True): + prob = prob * 0.999 + 5e-4 + ce = - soft_y* torch.log(prob) + ce = torch.sum(ce, dim = 1) # shape is [N] + if(size_avg): + ce = torch.mean(ce) + return ce + +@torch.no_grad() +def select_criterion(no_noisy_sample, cl_noisy_sample, label): + """ + no_noisy_sample: noisy branch's output probability for noisy sample + cl_noisy_sample: clean branch's output probability for noisy sample + label: noisy label + """ + l_n = get_ce(no_noisy_sample, label, size_avg = False) + l_c = get_ce(cl_noisy_sample, label, size_avg = False) + js_distance = ConsistLoss() + variance = js_distance(no_noisy_sample, cl_noisy_sample, size_average=False) + exp_variance = torch.exp(-16 * variance) + loss_n = torch.mean(l_c * exp_variance).item() + loss_c = torch.mean(l_n * exp_variance).item() + return loss_n, loss_c + +class NLLDAST(SegmentationAgent): + def __init__(self, config, stage = 'train'): + super(NLLDAST, self).__init__(config, stage) + self.train_set_noise = None + self.train_loader_noise = None + self.trainIter_noise = None + self.noisy_rank = None + self.clean_rank = None + + def get_noisy_dataset_from_config(self): + root_dir = self.config['dataset']['root_dir'] + modal_num = self.config['dataset'].get('modal_num', 1) + transform_names = self.config['dataset']['train_transform'] + + self.transform_list = [] + if(transform_names is None or len(transform_names) == 0): + data_transform = None + else: + transform_param = self.config['dataset'] + transform_param['task'] = 'segmentation' + for name in transform_names: + if(name not in self.transform_dict): + raise(ValueError("Undefined transform {0:}".format(name))) + one_transform = self.transform_dict[name](transform_param) + self.transform_list.append(one_transform) + data_transform = transforms.Compose(self.transform_list) + + csv_file = self.config['dataset'].get('train_csv_noise', None) + dataset = NiftyDataset(root_dir=root_dir, + csv_file = csv_file, + modal_num = modal_num, + with_label= True, + transform = data_transform ) + return dataset + + def create_dataset(self): + super(NLLDAST, self).create_dataset() + if(self.stage == 'train'): + if(self.train_set_noise is None): + self.train_set_noise = self.get_noisy_dataset_from_config() + if(self.deterministic): + def worker_init_fn(worker_id): + random.seed(self.random_seed + worker_id) + worker_init = worker_init_fn + else: + worker_init = None + + bn_train_noise = self.config['dataset']['train_batch_size_noise'] + num_worker = 
self.config['dataset'].get('num_workder', 16)
+            self.train_loader_noise = torch.utils.data.DataLoader(self.train_set_noise,
+                batch_size = bn_train_noise, shuffle=True, num_workers= num_worker,
+                worker_init_fn=worker_init)
+
+    def training(self):
+        class_num   = self.config['network']['class_num']
+        iter_valid  = self.config['training']['iter_valid']
+        nll_cfg     = self.config['noisy_label_learning']
+        iter_max    = self.config['training']['iter_max']
+        rampup_start = nll_cfg.get('rampup_start', 0)
+        rampup_end   = nll_cfg.get('rampup_end', iter_max)
+        train_loss  = 0
+        train_loss_sup = 0
+        train_loss_reg = 0
+        train_dice_list = []
+        self.net.train()
+
+        rank_length  = nll_cfg.get("dast_rank_length", 20)
+        consist_loss = ConsistLoss()
+        for it in range(iter_valid):
+            try:
+                data_cl = next(self.trainIter)
+            except StopIteration:
+                self.trainIter = iter(self.train_loader)
+                data_cl = next(self.trainIter)
+            try:
+                data_no = next(self.trainIter_noise)
+            except StopIteration:
+                self.trainIter_noise = iter(self.train_loader_noise)
+                data_no = next(self.trainIter_noise)
+
+            # get the inputs
+            x0 = self.convert_tensor_type(data_cl['image'])   # clean sample
+            y0 = self.convert_tensor_type(data_cl['label_prob'])
+            x1 = self.convert_tensor_type(data_no['image'])   # noisy sample
+            y1 = self.convert_tensor_type(data_no['label_prob'])
+            inputs = torch.cat([x0, x1], dim = 0).to(self.device)
+            y0, y1 = y0.to(self.device), y1.to(self.device)
+
+            # zero the parameter gradients
+            self.optimizer.zero_grad()
+
+            # forward + backward + optimize
+            b0_pred, b1_pred = self.net(inputs)
+            n0 = list(x0.shape)[0]        # number of clean samples
+            b0_x0_pred = b0_pred[:n0]     # prediction of clean samples from clean branch
+            b0_x1_pred = b0_pred[n0:]     # prediction of noisy samples from clean branch
+            b1_x1_pred = b1_pred[n0:]     # prediction of noisy samples from noisy branch
+
+            # supervised loss for the clean and noisy branches, respectively
+            loss_sup_cl = self.get_loss_value(data_cl, b0_x0_pred, y0)
+            loss_sup_no = self.get_loss_value(data_no, b1_x1_pred, y1)
+            loss_sup    = (loss_sup_cl + loss_sup_no) / 2
+            loss = loss_sup
+
+            # Severe Noise Suppression & Supplementary Training
+            rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid")
+            w_dbc = nll_cfg.get('dast_dbc_w', 0.1) * rampup_ratio
+            w_st  = nll_cfg.get('dast_st_w', 0.1)  * rampup_ratio
+            b1_x1_prob = nn.Softmax(dim = 1)(b1_x1_pred)
+            b0_x1_prob = nn.Softmax(dim = 1)(b0_x1_pred)
+            loss_n, loss_c = select_criterion(b1_x1_prob, b0_x1_prob, y1)
+            rank_n = self.noisy_rank.add_val(loss_n)
+            rank_c = self.clean_rank.add_val(loss_c)
+            if loss_n < loss_c:
+                select_ratio = nll_cfg.get('dast_select_ratio', 0.2)
+                if rank_c >= rank_length * (1 - select_ratio):
+                    loss_dbc = consist_loss(b1_x1_prob, b0_x1_prob)
+                    loss = loss + loss_dbc * w_dbc
+                if rank_n <= rank_length * select_ratio:
+                    b0_x1_argmax = torch.argmax(b0_x1_pred, dim = 1, keepdim = True)
+                    b0_x1_lab = get_soft_label(b0_x1_argmax, class_num, self.tensor_type)
+                    b1_x1_argmax = torch.argmax(b1_x1_pred, dim = 1, keepdim = True)
+                    b1_x1_lab = get_soft_label(b1_x1_argmax, class_num, self.tensor_type)
+                    pseudo_label = (b0_x1_lab + b1_x1_lab + y1) / 3
+                    sharpen = lambda p,T: p**(1.0/T)/(p**(1.0/T) + (1-p)**(1.0/T))
+                    b0_x1_prob = nn.Softmax(dim = 1)(b0_x1_pred)
+                    loss_st = torch.mean(torch.abs(b0_x1_prob - sharpen(pseudo_label, 0.5)))
+                    loss = loss + loss_st * w_st
+
+            loss.backward()
+            self.optimizer.step()
+            if(self.scheduler is not None and \
+                not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)):
self.scheduler.step() + + train_loss = train_loss + loss.item() + train_loss_sup = train_loss_sup + loss_sup.item() + # train_loss_reg = train_loss_reg + loss_reg.item() + # get dice evaluation for each class in annotated images + if(isinstance(b0_x0_pred, tuple) or isinstance(b0_x0_pred, list)): + p0 = b0_x0_pred[0] + else: + p0 = b0_x0_pred + p0_argmax = torch.argmax(p0, dim = 1, keepdim = True) + p0_soft = get_soft_label(p0_argmax, class_num, self.tensor_type) + p0_soft, y0 = reshape_prediction_and_ground_truth(p0_soft, y0) + dice_list = get_classwise_dice(p0_soft, y0) + train_dice_list.append(dice_list.cpu().numpy()) + train_avg_loss = train_loss / iter_valid + train_avg_loss_sup = train_loss_sup / iter_valid + train_avg_loss_reg = train_loss_reg / iter_valid + train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) + train_avg_dice = train_cls_dice.mean() + + train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, + 'loss_reg':train_avg_loss_reg, 'regular_w':w_dbc, + 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice} + return train_scalers + + def train_valid(self): + self.trainIter_noise = iter(self.train_loader_noise) + nll_cfg = self.config['noisy_label_learning'] + rank_length = nll_cfg.get("dast_rank_length", 20) + self.noisy_rank = Rank(rank_length) + self.clean_rank = Rank(rank_length) + super(NLLDAST, self).train_valid() \ No newline at end of file diff --git a/pymic/net_run_nll/nll_main.py b/pymic/net_run_nll/nll_main.py new file mode 100644 index 0000000..cc07a44 --- /dev/null +++ b/pymic/net_run_nll/nll_main.py @@ -0,0 +1,39 @@ + +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import logging +import os +import sys +from pymic.util.parse_config import * +from pymic.net_run_nll.nll_co_teaching import NLLCoTeaching +from pymic.net_run_nll.nll_trinet import NLLTriNet +from pymic.net_run_nll.nll_dast import NLLDAST + +NLLMethodDict = {'CoTeaching': NLLCoTeaching, + "TriNet": NLLTriNet, + "DAST": NLLDAST} + +def main(): + if(len(sys.argv) < 3): + print('Number of arguments should be 3. e.g.') + print(' pymic_nll train config.cfg') + exit() + stage = str(sys.argv[1]) + cfg_file = str(sys.argv[2]) + config = parse_config(cfg_file) + config = synchronize_config(config) + log_dir = config['training']['ckpt_save_dir'] + if(not os.path.exists(log_dir)): + os.mkdir(log_dir) + logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, + format='%(message)s') + logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) + logging_config(config) + nll_method = config['noisy_label_learning']['nll_method'] + agent = NLLMethodDict[nll_method](config, stage) + agent.run() + +if __name__ == "__main__": + main() + + \ No newline at end of file diff --git a/pymic/net_run_nll/nll_trinet.py b/pymic/net_run_nll/nll_trinet.py new file mode 100644 index 0000000..6af5449 --- /dev/null +++ b/pymic/net_run_nll/nll_trinet.py @@ -0,0 +1,171 @@ +# -*- coding: utf-8 -*- +""" +Implementation of trinet for learning from noisy samples for +segmentation tasks according to the following paper: + Tianwei Zhang, Lequan Yu, Na Hu, Su Lv, Shi Gu: + Robust Medical Image Segmentation from Non-expert Annotations with Tri-network. + MICCAI 2020. 
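A rough sketch of the consensus-based voxel selection described here: each network computes a per-voxel cross entropy against the noisy labels, keeps its low-loss voxels via a quantile threshold, and every branch is then supervised only where its two peers agree that the label looks clean. Tensors and the helper name are illustrative, not PyMIC code.

```python
# Toy sketch of tri-network consensus-based voxel selection (not PyMIC code).
import torch

def loss_and_mask(logits, labels_prob, keep_ratio):
    prob = torch.softmax(logits, dim=1) * 0.999 + 5e-4
    ce = -(labels_prob * torch.log(prob)).sum(dim=1).flatten()  # per-voxel cross entropy
    return ce, ce < torch.quantile(ce, keep_ratio)              # keep the low-loss voxels

N, C, H, W = 2, 2, 8, 8
labels = torch.eye(C)[torch.randint(0, C, (N, H, W))].permute(0, 3, 1, 2)
out1, out2, out3 = (torch.randn(N, C, H, W) for _ in range(3))

l1, m1 = loss_and_mask(out1, labels, 0.8)
l2, m2 = loss_and_mask(out2, labels, 0.8)
l3, m3 = loss_and_mask(out3, labels, 0.8)

mask23 = (m2 & m3).float()   # voxels that both peers of branch 1 consider clean
loss1  = (l1 * mask23).sum() / mask23.sum()
```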
+ https://link.springer.com/chapter/10.1007/978-3-030-59719-1_25 +""" +from __future__ import print_function, division +import logging +import os +import sys +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from torch.optim import lr_scheduler +from pymic.loss.seg.util import get_soft_label +from pymic.loss.seg.util import reshape_prediction_and_ground_truth +from pymic.loss.seg.util import get_classwise_dice +from pymic.loss.seg.util import reshape_tensor_to_2D +from pymic.net_run.agent_seg import SegmentationAgent +from pymic.net.net_dict_seg import SegNetDict +from pymic.util.parse_config import * +from pymic.util.ramps import get_rampup_ratio + + + +class TriNet(nn.Module): + def __init__(self, params): + super(TriNet, self).__init__() + net_name = params['net_type'] + self.net1 = SegNetDict[net_name](params) + self.net2 = SegNetDict[net_name](params) + self.net3 = SegNetDict[net_name](params) + + def forward(self, x): + out1 = self.net1(x) + out2 = self.net2(x) + out3 = self.net3(x) + + if(self.training): + return out1, out2, out3 + else: + return (out1 + out2 + out3) / 3 + +class NLLTriNet(SegmentationAgent): + def __init__(self, config, stage = 'train'): + super(NLLTriNet, self).__init__(config, stage) + + def create_network(self): + if(self.net is None): + self.net = TriNet(self.config['network']) + if(self.tensor_type == 'float'): + self.net.float() + else: + self.net.double() + + def get_loss_and_confident_mask(self, pred, labels_prob, conf_ratio): + prob = nn.Softmax(dim = 1)(pred) + prob_2d = reshape_tensor_to_2D(prob) * 0.999 + 5e-4 + y_2d = reshape_tensor_to_2D(labels_prob) + + loss = - y_2d* torch.log(prob_2d) + loss = torch.sum(loss, dim = 1) # shape is [N] + threshold = torch.quantile(loss, conf_ratio) + mask = loss < threshold + return loss, mask + + def training(self): + class_num = self.config['network']['class_num'] + iter_valid = self.config['training']['iter_valid'] + nll_cfg = self.config['noisy_label_learning'] + iter_max = self.config['training']['iter_max'] + select_ratio = nll_cfg['trinet_select_ratio'] + rampup_start = nll_cfg.get('rampup_start', 0) + rampup_end = nll_cfg.get('rampup_end', iter_max) + + train_loss_no_select1 = 0 + train_loss_no_select2 = 0 + train_loss1, train_loss2, train_loss3 = 0, 0, 0 + train_dice_list = [] + self.net.train() + for it in range(iter_valid): + try: + data = next(self.trainIter) + except StopIteration: + self.trainIter = iter(self.train_loader) + data = next(self.trainIter) + + # get the inputs + inputs = self.convert_tensor_type(data['image']) + labels_prob = self.convert_tensor_type(data['label_prob']) + inputs, labels_prob = inputs.to(self.device), labels_prob.to(self.device) + + # zero the parameter gradients + self.optimizer.zero_grad() + + # forward + backward + optimize + outputs1, outputs2, outputs3 = self.net(inputs) + + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end) + forget_ratio = (1 - select_ratio) * rampup_ratio + remb_ratio = 1 - forget_ratio + + loss1, mask1 = self.get_loss_and_confident_mask(outputs1, labels_prob, remb_ratio) + loss2, mask2 = self.get_loss_and_confident_mask(outputs2, labels_prob, remb_ratio) + loss3, mask3 = self.get_loss_and_confident_mask(outputs3, labels_prob, remb_ratio) + mask12, mask13, mask23 = mask1 * mask2, mask1 * mask3, mask2 * mask3 + mask12, mask13, mask23 = mask12.detach(), mask13.detach(), mask23.detach() + + loss1_avg = torch.sum(loss1 * mask23) / mask23.sum() + loss2_avg = torch.sum(loss2 * mask13) / mask13.sum() 
+ loss3_avg = torch.sum(loss3 * mask12) / mask12.sum() + loss = (loss1_avg + loss2_avg + loss3_avg) / 3 + + loss.backward() + self.optimizer.step() + if(self.scheduler is not None and \ + not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): + self.scheduler.step() + + train_loss_no_select1 = train_loss_no_select1 + loss1.mean().item() + train_loss_no_select2 = train_loss_no_select2 + loss2.mean().item() + train_loss1 = train_loss1 + loss1_avg.item() + train_loss2 = train_loss2 + loss2_avg.item() + + outputs1_argmax = torch.argmax(outputs1, dim = 1, keepdim = True) + soft_out1 = get_soft_label(outputs1_argmax, class_num, self.tensor_type) + soft_out1, labels_prob = reshape_prediction_and_ground_truth(soft_out1, labels_prob) + dice_list = get_classwise_dice(soft_out1, labels_prob).detach().cpu().numpy() + train_dice_list.append(dice_list) + train_avg_loss_no_select1 = train_loss_no_select1 / iter_valid + train_avg_loss_no_select2 = train_loss_no_select2 / iter_valid + train_avg_loss1 = train_loss1 / iter_valid + train_avg_loss2 = train_loss2 / iter_valid + train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) + train_avg_dice = train_cls_dice.mean() + + train_scalers = {'loss': (train_avg_loss1 + train_avg_loss2) / 2, + 'loss1':train_avg_loss1, 'loss2': train_avg_loss2, + 'loss_no_select1':train_avg_loss_no_select1, + 'loss_no_select2':train_avg_loss_no_select2, + 'select_ratio':remb_ratio, 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice} + return train_scalers + + def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): + loss_scalar ={'train':train_scalars['loss'], + 'valid':valid_scalars['loss']} + loss_no_select_scalar = {'net1':train_scalars['loss_no_select1'], + 'net2':train_scalars['loss_no_select2']} + + dice_scalar ={'train':train_scalars['avg_dice'], 'valid':valid_scalars['avg_dice']} + self.summ_writer.add_scalars('loss', loss_scalar, glob_it) + self.summ_writer.add_scalars('loss_no_select', loss_no_select_scalar, glob_it) + self.summ_writer.add_scalars('select_ratio', {'select_ratio':train_scalars['select_ratio']}, glob_it) + self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it) + self.summ_writer.add_scalars('dice', dice_scalar, glob_it) + class_num = self.config['network']['class_num'] + for c in range(class_num): + cls_dice_scalar = {'train':train_scalars['class_dice'][c], \ + 'valid':valid_scalars['class_dice'][c]} + self.summ_writer.add_scalars('class_{0:}_dice'.format(c), cls_dice_scalar, glob_it) + + logging.info('train loss {0:.4f}, avg dice {1:.4f} '.format( + train_scalars['loss'], train_scalars['avg_dice']) + "[" + \ + ' '.join("{0:.4f}".format(x) for x in train_scalars['class_dice']) + "]") + logging.info('valid loss {0:.4f}, avg dice {1:.4f} '.format( + valid_scalars['loss'], valid_scalars['avg_dice']) + "[" + \ + ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") diff --git a/pymic/net_run_ssl/__init__.py b/pymic/net_run_ssl/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pymic/net_run_ssl/ssl_abstract.py b/pymic/net_run_ssl/ssl_abstract.py new file mode 100644 index 0000000..1d18c4d --- /dev/null +++ b/pymic/net_run_ssl/ssl_abstract.py @@ -0,0 +1,98 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import logging +import numpy as np +import random +import torch +import torchvision.transforms as transforms +from pymic.io.nifty_dataset import NiftyDataset +from pymic.loss.seg.util import get_soft_label +from pymic.loss.seg.util import 
reshape_prediction_and_ground_truth
+from pymic.loss.seg.util import get_classwise_dice
+from pymic.loss.seg.ssl import EntropyLoss
+from pymic.net_run.agent_seg import SegmentationAgent
+from pymic.transform.trans_dict import TransformDict
+
+class SSLSegAgent(SegmentationAgent):
+    """
+    Abstract agent for semi-supervised segmentation. In addition to the labeled
+    training set handled by SegmentationAgent, it builds an unlabeled dataset and
+    data loader from the configuration; child classes such as SSLEntropyMinimization,
+    SSLMeanTeacher, SSLCCT, SSLCPS and SSLURPC implement the concrete
+    regularization strategies in their training() methods.
+    """
+    def __init__(self, config, stage = 'train'):
+        super(SSLSegAgent, self).__init__(config, stage)
+        self.transform_dict = TransformDict
+        self.train_set_unlab = None
+
+    def get_unlabeled_dataset_from_config(self):
+        root_dir  = self.config['dataset']['root_dir']
+        modal_num = self.config['dataset'].get('modal_num', 1)
+        transform_names = self.config['dataset']['train_transform_unlab']
+
+        self.transform_list = []
+        if(transform_names is None or len(transform_names) == 0):
+            data_transform = None
+        else:
+            transform_param = self.config['dataset']
+            transform_param['task'] = 'segmentation'
+            for name in transform_names:
+                if(name not in self.transform_dict):
+                    raise(ValueError("Undefined transform {0:}".format(name)))
+                one_transform = self.transform_dict[name](transform_param)
+                self.transform_list.append(one_transform)
+            data_transform = transforms.Compose(self.transform_list)
+
+        csv_file = self.config['dataset'].get('train_csv_unlab', None)
+        dataset  = NiftyDataset(root_dir=root_dir,
+                                csv_file  = csv_file,
+                                modal_num = modal_num,
+                                with_label= False,
+                                transform = data_transform )
+        return dataset
+
+    def create_dataset(self):
+        super(SSLSegAgent, self).create_dataset()
+        if(self.stage == 'train'):
+            if(self.train_set_unlab is None):
+                self.train_set_unlab = self.get_unlabeled_dataset_from_config()
+            if(self.deterministic):
+                def worker_init_fn(worker_id):
+                    random.seed(self.random_seed+worker_id)
+                worker_init = worker_init_fn
+            else:
+                worker_init = None
+
+            bn_train_unlab = self.config['dataset']['train_batch_size_unlab']
+            num_worker = self.config['dataset'].get('num_workder', 16)
+            self.train_loader_unlab = torch.utils.data.DataLoader(self.train_set_unlab,
+                batch_size = bn_train_unlab, shuffle=True, num_workers= num_worker,
+                worker_init_fn=worker_init)
+
+    def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it):
+        loss_scalar ={'train':train_scalars['loss'],
+                      'valid':valid_scalars['loss']}
+        loss_sup_scalar   = {'train':train_scalars['loss_sup']}
+        loss_upsup_scalar = {'train':train_scalars['loss_reg']}
+        dice_scalar ={'train':train_scalars['avg_dice'], 'valid':valid_scalars['avg_dice']}
+        self.summ_writer.add_scalars('loss', loss_scalar, glob_it)
+        self.summ_writer.add_scalars('loss_sup', loss_sup_scalar, glob_it)
+        self.summ_writer.add_scalars('loss_reg', loss_upsup_scalar, glob_it)
+        self.summ_writer.add_scalars('regular_w', {'regular_w':train_scalars['regular_w']}, glob_it)
+        self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it)
+        self.summ_writer.add_scalars('dice', dice_scalar, glob_it)
+        class_num = self.config['network']['class_num']
+        for c in range(class_num):
+            cls_dice_scalar = {'train':train_scalars['class_dice'][c], \
+                'valid':valid_scalars['class_dice'][c]}
+            self.summ_writer.add_scalars('class_{0:}_dice'.format(c), cls_dice_scalar, glob_it)
+        logging.info('train loss {0:.4f}, avg dice {1:.4f} '.format(
+            train_scalars['loss'], train_scalars['avg_dice']) + "[" + \
+            ' '.join("{0:.4f}".format(x) for x in train_scalars['class_dice']) + "]")
+        logging.info('valid loss {0:.4f}, avg dice {1:.4f} '.format(
valid_scalars['loss'], valid_scalars['avg_dice']) + "[" + \ + ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") + + def train_valid(self): + self.trainIter_unlab = iter(self.train_loader_unlab) + super(SSLSegAgent, self).train_valid() diff --git a/pymic/net_run_ssl/ssl_cct.py b/pymic/net_run_ssl/ssl_cct.py new file mode 100644 index 0000000..d0c4f24 --- /dev/null +++ b/pymic/net_run_ssl/ssl_cct.py @@ -0,0 +1,154 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import logging +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from torch.optim import lr_scheduler +from pymic.loss.seg.util import get_soft_label +from pymic.loss.seg.util import reshape_prediction_and_ground_truth +from pymic.loss.seg.util import get_classwise_dice +from pymic.net_run_ssl.ssl_abstract import SSLSegAgent +from pymic.util.ramps import get_rampup_ratio + +def softmax_mse_loss(inputs, targets, conf_mask=False, threshold=None, use_softmax=False): + assert inputs.requires_grad == True and targets.requires_grad == False + assert inputs.size() == targets.size() # (batch_size * num_classes * H * W) + inputs = F.softmax(inputs, dim=1) + if use_softmax: + targets = F.softmax(targets, dim=1) + + if conf_mask: + loss_mat = F.mse_loss(inputs, targets, reduction='none') + mask = (targets.max(1)[0] > threshold) + loss_mat = loss_mat[mask.unsqueeze(1).expand_as(loss_mat)] + if loss_mat.shape.numel() == 0: loss_mat = torch.tensor([0.]).to(inputs.device) + return loss_mat.mean() + else: + return F.mse_loss(inputs, targets, reduction='mean') # take the mean over the batch_size + + +def softmax_kl_loss(inputs, targets, conf_mask=False, threshold=None, use_softmax=False): + assert inputs.requires_grad == True and targets.requires_grad == False + assert inputs.size() == targets.size() + input_log_softmax = F.log_softmax(inputs, dim=1) + if use_softmax: + targets = F.softmax(targets, dim=1) + + if conf_mask: + loss_mat = F.kl_div(input_log_softmax, targets, reduction='none') + mask = (targets.max(1)[0] > threshold) + loss_mat = loss_mat[mask.unsqueeze(1).expand_as(loss_mat)] + if loss_mat.shape.numel() == 0: loss_mat = torch.tensor([0.]).to(inputs.device) + return loss_mat.sum() / mask.shape.numel() + else: + return F.kl_div(input_log_softmax, targets, reduction='mean') + + +def softmax_js_loss(inputs, targets, **_): + assert inputs.requires_grad == True and targets.requires_grad == False + assert inputs.size() == targets.size() + epsilon = 1e-5 + + M = (F.softmax(inputs, dim=1) + targets) * 0.5 + kl1 = F.kl_div(F.log_softmax(inputs, dim=1), M, reduction='mean') + kl2 = F.kl_div(torch.log(targets+epsilon), M, reduction='mean') + return (kl1 + kl2) * 0.5 + +unsup_loss_dict = {"MSE": softmax_mse_loss, + "KL":softmax_kl_loss, + "JS":softmax_js_loss} + +class SSLCCT(SSLSegAgent): + """ + Cross-Consistency Training according to the following paper: + Yassine Ouali, Celine Hudelot and Myriam Tami: + Semi-Supervised Semantic Segmentation With Cross-Consistency Training. + CVPR 2020. 
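A short usage sketch for the `softmax_mse_loss` helper defined above, showing its contract: the first argument must require gradients, the target must be detached, and `use_softmax=True` lets the helper normalize the target internally. The tensors are toy values, not the agent's actual training loop.

```python
# Usage sketch for the consistency helper defined in this file (toy tensors).
import torch
from pymic.net_run_ssl.ssl_cct import softmax_mse_loss

N, C, H, W = 2, 3, 16, 16
aux_logits  = torch.randn(N, C, H, W, requires_grad=True)  # auxiliary decoder output
main_logits = torch.randn(N, C, H, W)                      # main decoder output

# the target must not require gradients; use_softmax=True softmaxes it inside the helper
loss_reg = softmax_mse_loss(aux_logits, main_logits.detach(), use_softmax=True)
loss_reg.backward()
```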
+ https://arxiv.org/abs/2003.09005 + Code adapted from: https://github.com/yassouali/CCT + """ + def training(self): + class_num = self.config['network']['class_num'] + iter_valid = self.config['training']['iter_valid'] + ssl_cfg = self.config['semi_supervised_learning'] + iter_max = self.config['training']['iter_max'] + rampup_start = ssl_cfg.get('rampup_start', 0) + rampup_end = ssl_cfg.get('rampup_end', iter_max) + unsup_loss_name = ssl_cfg.get('unsupervised_loss', "MSE") + self.unsup_loss_f = unsup_loss_dict[unsup_loss_name] + train_loss = 0 + train_loss_sup = 0 + train_loss_reg = 0 + train_dice_list = [] + self.net.train() + + for it in range(iter_valid): + try: + data_lab = next(self.trainIter) + except StopIteration: + self.trainIter = iter(self.train_loader) + data_lab = next(self.trainIter) + try: + data_unlab = next(self.trainIter_unlab) + except StopIteration: + self.trainIter_unlab = iter(self.train_loader_unlab) + data_unlab = next(self.trainIter_unlab) + + # get the inputs + x0 = self.convert_tensor_type(data_lab['image']) + y0 = self.convert_tensor_type(data_lab['label_prob']) + x1 = self.convert_tensor_type(data_unlab['image']) + inputs = torch.cat([x0, x1], dim = 0) + inputs, y0 = inputs.to(self.device), y0.to(self.device) + + # zero the parameter gradients + self.optimizer.zero_grad() + + # forward pass + output, aux_outputs = self.net(inputs) + n0 = list(x0.shape)[0] + + # get supervised loss + p0 = output[:n0] + loss_sup = self.get_loss_value(data_lab, p0, y0) + + # get regularization loss + p1 = F.softmax(output[n0:].detach(), dim=1) + p1_aux = [aux_out[n0:] for aux_out in aux_outputs] + loss_reg = 0.0 + for p1_auxi in p1_aux: + loss_reg += self.unsup_loss_f( p1_auxi, p1, use_softmax = True) + loss_reg = loss_reg / len(p1_aux) + + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") + regular_w = ssl_cfg.get('regularize_w', 0.1) * rampup_ratio + loss = loss_sup + regular_w*loss_reg + + loss.backward() + self.optimizer.step() + if(self.scheduler is not None and \ + not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): + self.scheduler.step() + train_loss = train_loss + loss.item() + train_loss_sup = train_loss_sup + loss_sup.item() + train_loss_reg = train_loss_reg + loss_reg.item() + # get dice evaluation for each class in annotated images + if(isinstance(p0, tuple) or isinstance(p0, list)): + p0 = p0[0] + p0_argmax = torch.argmax(p0, dim = 1, keepdim = True) + p0_soft = get_soft_label(p0_argmax, class_num, self.tensor_type) + p0_soft, y0 = reshape_prediction_and_ground_truth(p0_soft, y0) + dice_list = get_classwise_dice(p0_soft, y0) + train_dice_list.append(dice_list.cpu().numpy()) + train_avg_loss = train_loss / iter_valid + train_avg_loss_sup = train_loss_sup / iter_valid + train_avg_loss_reg = train_loss_reg / iter_valid + train_cls_dice = np.asarray(train_dice_list).mean(axis = 0) + train_avg_dice = train_cls_dice.mean() + + train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, + 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, + 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice} + return train_scalers diff --git a/pymic/net_run_ssl/ssl_cps.py b/pymic/net_run_ssl/ssl_cps.py index e9bc41b..2264d0d 100644 --- a/pymic/net_run_ssl/ssl_cps.py +++ b/pymic/net_run_ssl/ssl_cps.py @@ -3,16 +3,32 @@ import logging import numpy as np import torch -import torch.optim as optim +import torch.nn as nn +from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util 
import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice -from pymic.util.ramps import sigmoid_rampup -from pymic.net_run.get_optimizer import get_optimiser -from pymic.net_run_ssl.ssl_em import SSLEntropyMinimization +from pymic.net_run_ssl.ssl_abstract import SSLSegAgent from pymic.net.net_dict_seg import SegNetDict +from pymic.util.ramps import get_rampup_ratio -class SSLCrossPseudoSupervision(SSLEntropyMinimization): +class BiNet(nn.Module): + def __init__(self, params): + super(BiNet, self).__init__() + net_name = params['net_type'] + self.net1 = SegNetDict[net_name](params) + self.net2 = SegNetDict[net_name](params) + + def forward(self, x): + out1 = self.net1(x) + out2 = self.net2(x) + + if(self.training): + return out1, out2 + else: + return (out1 + out2) / 3 + +class SSLCPS(SSLSegAgent): """ Using cross pseudo supervision according to the following paper: Xiaokang Chen, Yuhui Yuan, Gang Zeng, Jingdong Wang, @@ -21,51 +37,28 @@ class SSLCrossPseudoSupervision(SSLEntropyMinimization): https://arxiv.org/abs/2106.01226 """ def __init__(self, config, stage = 'train'): - super(SSLCrossPseudoSupervision, self).__init__(config, stage) - self.net2 = None - self.optimizer2 = None - self.scheduler2 = None + super(SSLCPS, self).__init__(config, stage) def create_network(self): - super(SSLCrossPseudoSupervision, self).create_network() - if(self.net2 is None): - net_name = self.config['network']['net_type'] - if(net_name not in SegNetDict): - raise ValueError("Undefined network {0:}".format(net_name)) - self.net2 = SegNetDict[net_name](self.config['network']) + if(self.net is None): + self.net = BiNet(self.config['network']) if(self.tensor_type == 'float'): - self.net2.float() + self.net.float() else: - self.net2.double() - - def train_valid(self): - # create optimizor for the second network - if(self.optimizer2 is None): - self.optimizer2 = get_optimiser(self.config['training']['optimizer'], - self.net2.parameters(), - self.config['training']) - last_iter = -1 - # if(self.checkpoint is not None): - # self.optimizer2.load_state_dict(self.checkpoint['optimizer_state_dict']) - # last_iter = self.checkpoint['iteration'] - 1 - if(self.scheduler2 is None): - self.scheduler2 = optim.lr_scheduler.MultiStepLR(self.optimizer2, - self.config['training']['lr_milestones'], - self.config['training']['lr_gamma'], - last_epoch = last_iter) - super(SSLCrossPseudoSupervision, self).train_valid() + self.net.double() def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] ssl_cfg = self.config['semi_supervised_learning'] + iter_max = self.config['training']['iter_max'] + rampup_start = ssl_cfg.get('rampup_start', 0) + rampup_end = ssl_cfg.get('rampup_end', iter_max) train_loss = 0 train_loss_sup1, train_loss_pseudo_sup1 = 0, 0 train_loss_sup2, train_loss_pseudo_sup2 = 0, 0 train_dice_list = [] self.net.train() - self.net2.train() - self.net2.to(self.device) for it in range(iter_valid): try: data_lab = next(self.trainIter) @@ -87,9 +80,8 @@ def training(self): # zero the parameter gradients self.optimizer.zero_grad() - self.optimizer2.zero_grad() - outputs1, outputs2 = self.net(inputs), self.net2(inputs) + outputs1, outputs2 = self.net(inputs) outputs_soft1 = torch.softmax(outputs1, dim=1) outputs_soft2 = torch.softmax(outputs2, dim=1) @@ -107,13 +99,8 @@ def training(self): pse_sup1 = self.get_loss_value(data_unlab, outputs1[n0:], pse_prob2) pse_sup2 = self.get_loss_value(data_unlab, outputs2[n0:], pse_prob1) - 
iter_max = self.config['training']['iter_max'] - ramp_up_len = ssl_cfg.get('ramp_up_length', iter_max) - regular_w = 0.0 - if(self.glob_it > ssl_cfg.get('iter_sup', 0)): - regular_w = ssl_cfg.get('regularize_w', 0.1) - if(ramp_up_len is not None and self.glob_it < ramp_up_len): - regular_w = regular_w * sigmoid_rampup(self.glob_it, ramp_up_len) + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") + regular_w = ssl_cfg.get('regularize_w', 0.1) * rampup_ratio model1_loss = loss_sup1 + regular_w * pse_sup1 model2_loss = loss_sup2 + regular_w * pse_sup2 @@ -121,9 +108,9 @@ def training(self): loss.backward() self.optimizer.step() - self.scheduler.step() - self.optimizer2.step() - self.scheduler2.step() + if(self.scheduler is not None and \ + not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): + self.scheduler.step() train_loss = train_loss + loss.item() train_loss_sup1 = train_loss_sup1 + loss_sup1.item() @@ -152,8 +139,8 @@ def training(self): 'loss_pse_sup1':train_avg_loss_pse_sup1, 'loss_pse_sup2': train_avg_loss_pse_sup2, 'regular_w':regular_w, 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice} return train_scalers - - def write_scalars(self, train_scalars, valid_scalars, glob_it): + + def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): loss_scalar ={'train':train_scalars['loss'], 'valid':valid_scalars['loss']} loss_sup_scalar = {'net1':train_scalars['loss_sup1'], @@ -165,6 +152,7 @@ def write_scalars(self, train_scalars, valid_scalars, glob_it): self.summ_writer.add_scalars('loss_sup', loss_sup_scalar, glob_it) self.summ_writer.add_scalars('loss_pseudo_sup', loss_pse_sup_scalar, glob_it) self.summ_writer.add_scalars('regular_w', {'regular_w':train_scalars['regular_w']}, glob_it) + self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it) self.summ_writer.add_scalars('dice', dice_scalar, glob_it) class_num = self.config['network']['class_num'] for c in range(class_num): diff --git a/pymic/net_run_ssl/ssl_em.py b/pymic/net_run_ssl/ssl_em.py index 32f85c1..49dd22f 100644 --- a/pymic/net_run_ssl/ssl_em.py +++ b/pymic/net_run_ssl/ssl_em.py @@ -2,78 +2,36 @@ from __future__ import print_function, division import logging import numpy as np -import random import torch -import torchvision.transforms as transforms -from pymic.io.nifty_dataset import NiftyDataset +from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.ssl import EntropyLoss -from pymic.net_run.agent_seg import SegmentationAgent +from pymic.net_run_ssl.ssl_abstract import SSLSegAgent from pymic.transform.trans_dict import TransformDict -from pymic.util.ramps import sigmoid_rampup +from pymic.util.ramps import get_rampup_ratio -class SSLEntropyMinimization(SegmentationAgent): +class SSLEntropyMinimization(SSLSegAgent): """ Implementation of the following paper: - Yves Grandvalet and Yoshua Bengio, + Yves Grandvalet and Yoshua Bengio: Semi-supervised Learningby Entropy Minimization. - NeurIPS, 2005. + NeurIPS, 2005. 
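The unsupervised term this class adds is just the mean prediction entropy on unlabeled images; a minimal torch sketch of that term is shown here, written directly with torch rather than with `pymic.loss.seg.ssl.EntropyLoss`, whose internals are not part of this diff.

```python
# Minimal sketch of the entropy-minimization regularizer (toy tensors, not PyMIC code).
import torch

logits = torch.randn(2, 3, 16, 16, requires_grad=True)  # predictions on unlabeled images
prob = torch.softmax(logits, dim=1)
entropy = -(prob * torch.log(prob + 1e-6)).sum(dim=1)   # per-pixel entropy
loss_reg = entropy.mean()                               # low entropy = confident predictions
(0.1 * loss_reg).backward()                             # in the agent: regularize_w * rampup_ratio
```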
+ https://papers.nips.cc/paper/2004/file/96f2b50b5d3613adf9c27049b2a888c7-Paper.pdf """ def __init__(self, config, stage = 'train'): super(SSLEntropyMinimization, self).__init__(config, stage) self.transform_dict = TransformDict self.train_set_unlab = None - def get_unlabeled_dataset_from_config(self): - root_dir = self.config['dataset']['root_dir'] - modal_num = self.config['dataset']['modal_num'] - transform_names = self.config['dataset']['train_transform_unlab'] - - self.transform_list = [] - if(transform_names is None or len(transform_names) == 0): - data_transform = None - else: - transform_param = self.config['dataset'] - transform_param['task'] = 'segmentation' - for name in transform_names: - if(name not in self.transform_dict): - raise(ValueError("Undefined transform {0:}".format(name))) - one_transform = self.transform_dict[name](transform_param) - self.transform_list.append(one_transform) - data_transform = transforms.Compose(self.transform_list) - - csv_file = self.config['dataset'].get('train_csv_unlab', None) - dataset = NiftyDataset(root_dir=root_dir, - csv_file = csv_file, - modal_num = modal_num, - with_label= False, - transform = data_transform ) - return dataset - - def create_dataset(self): - super(SSLEntropyMinimization, self).create_dataset() - if(self.stage == 'train'): - if(self.train_set_unlab is None): - self.train_set_unlab = self.get_unlabeled_dataset_from_config() - if(self.deterministic): - def worker_init_fn(worker_id): - random.seed(self.random_seed+worker_id) - worker_init = worker_init_fn - else: - worker_init = None - - bn_train_unlab = self.config['dataset']['train_batch_size_unlab'] - num_worker = self.config['dataset'].get('num_workder', 16) - self.train_loader_unlab = torch.utils.data.DataLoader(self.train_set_unlab, - batch_size = bn_train_unlab, shuffle=True, num_workers= num_worker, - worker_init_fn=worker_init) - def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] ssl_cfg = self.config['semi_supervised_learning'] + iter_max = self.config['training']['iter_max'] + rampup_start = ssl_cfg.get('rampup_start', 0) + rampup_end = ssl_cfg.get('rampup_end', iter_max) train_loss = 0 train_loss_sup = 0 train_loss_reg = 0 @@ -109,18 +67,16 @@ def training(self): loss_dict = {"prediction":outputs, 'softmax':True} loss_reg = EntropyLoss()(loss_dict) - iter_max = self.config['training']['iter_max'] - ramp_up_length = ssl_cfg.get('ramp_up_length', iter_max) - regular_w = 0.0 - if(self.glob_it > ssl_cfg.get('iter_sup', 0)): - regular_w = ssl_cfg.get('regularize_w', 0.1) - if(ramp_up_length is not None and self.glob_it < ramp_up_length): - regular_w = regular_w * sigmoid_rampup(self.glob_it, ramp_up_length) + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") + regular_w = ssl_cfg.get('regularize_w', 0.1) * rampup_ratio + loss = loss_sup + regular_w*loss_reg # if (self.config['training']['use']) loss.backward() self.optimizer.step() - self.scheduler.step() + if(self.scheduler is not None and \ + not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): + self.scheduler.step() train_loss = train_loss + loss.item() train_loss_sup = train_loss_sup + loss_sup.item() @@ -142,31 +98,4 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice} - return train_scalers - - def write_scalars(self, train_scalars, valid_scalars, 
glob_it): - loss_scalar ={'train':train_scalars['loss'], - 'valid':valid_scalars['loss']} - loss_sup_scalar = {'train':train_scalars['loss_sup']} - loss_upsup_scalar = {'train':train_scalars['loss_reg']} - dice_scalar ={'train':train_scalars['avg_dice'], 'valid':valid_scalars['avg_dice']} - self.summ_writer.add_scalars('loss', loss_scalar, glob_it) - self.summ_writer.add_scalars('loss_sup', loss_sup_scalar, glob_it) - self.summ_writer.add_scalars('loss_reg', loss_upsup_scalar, glob_it) - self.summ_writer.add_scalars('regular_w', {'regular_w':train_scalars['regular_w']}, glob_it) - self.summ_writer.add_scalars('dice', dice_scalar, glob_it) - class_num = self.config['network']['class_num'] - for c in range(class_num): - cls_dice_scalar = {'train':train_scalars['class_dice'][c], \ - 'valid':valid_scalars['class_dice'][c]} - self.summ_writer.add_scalars('class_{0:}_dice'.format(c), cls_dice_scalar, glob_it) - logging.info('train loss {0:.4f}, avg dice {1:.4f} '.format( - train_scalars['loss'], train_scalars['avg_dice']) + "[" + \ - ' '.join("{0:.4f}".format(x) for x in train_scalars['class_dice']) + "]") - logging.info('valid loss {0:.4f}, avg dice {1:.4f} '.format( - valid_scalars['loss'], valid_scalars['avg_dice']) + "[" + \ - ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") - - def train_valid(self): - self.trainIter_unlab = iter(self.train_loader_unlab) - super(SSLEntropyMinimization, self).train_valid() + return train_scalers \ No newline at end of file diff --git a/pymic/net_run_ssl/ssl_main.py b/pymic/net_run_ssl/ssl_main.py index cf5a8cd..d904ab1 100644 --- a/pymic/net_run_ssl/ssl_main.py +++ b/pymic/net_run_ssl/ssl_main.py @@ -8,14 +8,17 @@ from pymic.net_run_ssl.ssl_em import SSLEntropyMinimization from pymic.net_run_ssl.ssl_mt import SSLMeanTeacher from pymic.net_run_ssl.ssl_uamt import SSLUncertaintyAwareMeanTeacher +from pymic.net_run_ssl.ssl_cct import SSLCCT +from pymic.net_run_ssl.ssl_cps import SSLCPS from pymic.net_run_ssl.ssl_urpc import SSLURPC -from pymic.net_run_ssl.ssl_cps import SSLCrossPseudoSupervision + SSLMethodDict = {'EntropyMinimization': SSLEntropyMinimization, 'MeanTeacher': SSLMeanTeacher, 'UAMT': SSLUncertaintyAwareMeanTeacher, - 'URPC': SSLURPC, - 'CPS': SSLCrossPseudoSupervision} + 'CCT': SSLCCT, + 'CPS': SSLCPS, + 'URPC': SSLURPC} def main(): if(len(sys.argv) < 3): @@ -29,7 +32,7 @@ def main(): log_dir = config['training']['ckpt_save_dir'] if(not os.path.exists(log_dir)): os.mkdir(log_dir) - logging.basicConfig(filename=log_dir+"/log.txt", level=logging.INFO, + logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, format='%(message)s') logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) logging_config(config) diff --git a/pymic/net_run_ssl/ssl_mt.py b/pymic/net_run_ssl/ssl_mt.py index d25edbc..0456726 100644 --- a/pymic/net_run_ssl/ssl_mt.py +++ b/pymic/net_run_ssl/ssl_mt.py @@ -3,16 +3,21 @@ import logging import torch import numpy as np +from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice -from pymic.util.ramps import sigmoid_rampup -from pymic.net_run_ssl.ssl_em import SSLEntropyMinimization +from pymic.net_run_ssl.ssl_abstract import SSLSegAgent from pymic.net.net_dict_seg import SegNetDict +from pymic.util.ramps import get_rampup_ratio -class SSLMeanTeacher(SSLEntropyMinimization): +class SSLMeanTeacher(SSLSegAgent): """ - Training 
and testing agent for semi-supervised segmentation + Mean Teacher for semi-supervised learning according to the following paper: + Antti Tarvainen, Harri Valpola: Mean teachers are better role models: Weight-averaged + consistency targets improve semi-supervised deep learning results. + NeurIPS 2017. + https://arxiv.org/abs/1703.01780 """ def __init__(self, config, stage = 'train'): super(SSLMeanTeacher, self).__init__(config, stage) @@ -34,6 +39,9 @@ def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] ssl_cfg = self.config['semi_supervised_learning'] + iter_max = self.config['training']['iter_max'] + rampup_start = ssl_cfg.get('rampup_start', 0) + rampup_end = ssl_cfg.get('rampup_end', iter_max) train_loss = 0 train_loss_sup = 0 train_loss_reg = 0 @@ -77,19 +85,17 @@ def training(self): outputs_ema = self.net_ema(inputs_ema) p1_ema_soft = torch.softmax(outputs_ema, dim=1) - iter_max = self.config['training']['iter_max'] - ramp_up_length = ssl_cfg.get('ramp_up_length', iter_max) - regular_w = 0.0 - if(self.glob_it > ssl_cfg.get('iter_sup', 0)): - regular_w = ssl_cfg.get('regularize_w', 0.1) - if(ramp_up_length is not None and self.glob_it < ramp_up_length): - regular_w = regular_w * sigmoid_rampup(self.glob_it, ramp_up_length) + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") + regular_w = ssl_cfg.get('regularize_w', 0.1) * rampup_ratio + loss_reg = torch.nn.MSELoss()(p1_soft, p1_ema_soft) loss = loss_sup + regular_w*loss_reg loss.backward() self.optimizer.step() - self.scheduler.step() + if(self.scheduler is not None and \ + not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): + self.scheduler.step() # update EMA alpha = ssl_cfg.get('ema_decay', 0.99) diff --git a/pymic/net_run_ssl/ssl_uamt.py b/pymic/net_run_ssl/ssl_uamt.py index d1de32f..360dab1 100644 --- a/pymic/net_run_ssl/ssl_uamt.py +++ b/pymic/net_run_ssl/ssl_uamt.py @@ -3,11 +3,12 @@ import logging import torch import numpy as np +from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice -from pymic.util.ramps import sigmoid_rampup from pymic.net_run_ssl.ssl_mt import SSLMeanTeacher +from pymic.util.ramps import get_rampup_ratio class SSLUncertaintyAwareMeanTeacher(SSLMeanTeacher): """ @@ -21,6 +22,9 @@ def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] ssl_cfg = self.config['semi_supervised_learning'] + iter_max = self.config['training']['iter_max'] + rampup_start = ssl_cfg.get('rampup_start', 0) + rampup_end = ssl_cfg.get('rampup_end', iter_max) train_loss = 0 train_loss_sup = 0 train_loss_reg = 0 @@ -80,24 +84,21 @@ def training(self): uncertainty = -1.0 * torch.sum(preds*torch.log(preds + 1e-6), dim=1, keepdim=True) - iter_max = self.config['training']['iter_max'] - ramp_up_length = ssl_cfg.get('ramp_up_length', iter_max) - threshold_ramp = sigmoid_rampup(self.glob_it, iter_max) + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") class_num = list(y0.shape)[1] - threshold = (0.75+0.25*threshold_ramp)*np.log(class_num) + threshold = (0.75+0.25*rampup_ratio)*np.log(class_num) mask = (uncertainty < threshold).float() loss_reg = torch.sum(mask*square_error)/(2*torch.sum(mask)+1e-16) - regular_w = 0.0 - if(self.glob_it > ssl_cfg.get('iter_sup', 0)): - regular_w = 
ssl_cfg.get('regularize_w', 0.1) - if(ramp_up_length is not None and self.glob_it < ramp_up_length): - regular_w = regular_w * sigmoid_rampup(self.glob_it, ramp_up_length) + regular_w = ssl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg loss.backward() self.optimizer.step() - self.scheduler.step() + if(self.scheduler is not None and \ + not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): + self.scheduler.step() + # update EMA alpha = ssl_cfg.get('ema_decay', 0.99) diff --git a/pymic/net_run_ssl/ssl_urpc.py b/pymic/net_run_ssl/ssl_urpc.py index d2d5a1f..20b3d84 100644 --- a/pymic/net_run_ssl/ssl_urpc.py +++ b/pymic/net_run_ssl/ssl_urpc.py @@ -4,26 +4,29 @@ import torch import torch.nn as nn import numpy as np +from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice -from pymic.util.ramps import sigmoid_rampup -from pymic.net_run_ssl.ssl_em import SSLEntropyMinimization +from pymic.net_run_ssl.ssl_abstract import SSLSegAgent +from pymic.util.ramps import get_rampup_ratio -class SSLURPC(SSLEntropyMinimization): +class SSLURPC(SSLSegAgent): """ Uncertainty-Rectified Pyramid Consistency according to the following paper: - Xiangde Luo, Wenjun Liao, Jieneng Chen, Tao Song, Yinan Chen, - Shichuan Zhang, Nianyong Chen, Guotai Wang, Shaoting Zhang. - Efficient Semi-supervised Gross Target Volume of Nasopharyngeal Carcinoma - Segmentation via Uncertainty Rectified Pyramid Consistency. - MICCAI 2021, pp. 318-329. - https://arxiv.org/abs/2012.07042 + Xiangde Luo, Guotai Wang*, Wenjun Liao, Jieneng Chen, Tao Song, Yinan Chen, + Shichuan Zhang, Dimitris N. Metaxas, Shaoting Zhang. + Semi-Supervised Medical Image Segmentation via Uncertainty Rectified Pyramid Consistency . + Medical Image Analysis 2022. 
+ https://doi.org/10.1016/j.media.2022.102517 """ def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] ssl_cfg = self.config['semi_supervised_learning'] + iter_max = self.config['training']['iter_max'] + rampup_start = ssl_cfg.get('rampup_start', 0) + rampup_end = ssl_cfg.get('rampup_end', iter_max) train_loss = 0 train_loss_sup = 0 train_loss_reg = 0 @@ -78,20 +81,15 @@ def training(self): loss_reg += loss_i loss_reg = loss_reg / len(outputs_list) - iter_max = self.config['training']['iter_max'] - ramp_up_length = ssl_cfg.get('ramp_up_length', iter_max) - regular_w = 0.0 - if(self.glob_it > ssl_cfg.get('iter_sup', 0)): - regular_w = ssl_cfg.get('regularize_w', 0.1) - if(ramp_up_length is not None and self.glob_it < ramp_up_length): - regular_w = regular_w * sigmoid_rampup(self.glob_it, ramp_up_length) - + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") + regular_w = ssl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg loss.backward() self.optimizer.step() - self.scheduler.step() - + if(self.scheduler is not None and \ + not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): + self.scheduler.step() train_loss = train_loss + loss.item() train_loss_sup = train_loss_sup + loss_sup.item() train_loss_reg = train_loss_reg + loss_reg.item() diff --git a/pymic/net_run_wsl/__init__.py b/pymic/net_run_wsl/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pymic/net_run_wsl/wsl_abstract.py b/pymic/net_run_wsl/wsl_abstract.py new file mode 100644 index 0000000..d64063e --- /dev/null +++ b/pymic/net_run_wsl/wsl_abstract.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import logging +from pymic.net_run.agent_seg import SegmentationAgent + +class WSLSegAgent(SegmentationAgent): + """ + Training and testing agent for weakly supervised segmentation + """ + def __init__(self, config, stage = 'train'): + super(WSLSegAgent, self).__init__(config, stage) + + def training(self): + pass + + def write_scalars(self, train_scalars, valid_scalars, lr_value, glob_it): + loss_scalar ={'train':train_scalars['loss'], + 'valid':valid_scalars['loss']} + loss_sup_scalar = {'train':train_scalars['loss_sup']} + loss_upsup_scalar = {'train':train_scalars['loss_reg']} + dice_scalar ={'train':train_scalars['avg_dice'], 'valid':valid_scalars['avg_dice']} + self.summ_writer.add_scalars('loss', loss_scalar, glob_it) + self.summ_writer.add_scalars('loss_sup', loss_sup_scalar, glob_it) + self.summ_writer.add_scalars('loss_reg', loss_upsup_scalar, glob_it) + self.summ_writer.add_scalars('regular_w', {'regular_w':train_scalars['regular_w']}, glob_it) + self.summ_writer.add_scalars('lr', {"lr": lr_value}, glob_it) + self.summ_writer.add_scalars('dice', dice_scalar, glob_it) + class_num = self.config['network']['class_num'] + for c in range(class_num): + cls_dice_scalar = {'train':train_scalars['class_dice'][c], \ + 'valid':valid_scalars['class_dice'][c]} + self.summ_writer.add_scalars('class_{0:}_dice'.format(c), cls_dice_scalar, glob_it) + logging.info('train loss {0:.4f}, avg dice {1:.4f} '.format( + train_scalars['loss'], train_scalars['avg_dice']) + "[" + \ + ' '.join("{0:.4f}".format(x) for x in train_scalars['class_dice']) + "]") + logging.info('valid loss {0:.4f}, avg dice {1:.4f} '.format( + valid_scalars['loss'], valid_scalars['avg_dice']) + "[" + \ + ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + 
"]") diff --git a/pymic/net_run_wsl/wsl_dmpls.py b/pymic/net_run_wsl/wsl_dmpls.py index c42ed7a..8ee9e53 100644 --- a/pymic/net_run_wsl/wsl_dmpls.py +++ b/pymic/net_run_wsl/wsl_dmpls.py @@ -4,36 +4,37 @@ import numpy as np import random import torch -import torchvision.transforms as transforms -from pymic.io.nifty_dataset import NiftyDataset +from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.dice import DiceLoss -from pymic.loss.seg.ssl import TotalVariationLoss -from pymic.net_run.agent_seg import SegmentationAgent -from pymic.net_run_wsl.wsl_em import WSL_EntropyMinimization -from pymic.util.ramps import sigmoid_rampup +from pymic.net_run_wsl.wsl_abstract import WSLSegAgent +from pymic.util.ramps import get_rampup_ratio -class WSL_DMPLS(WSL_EntropyMinimization): +class WSLDMPLS(WSLSegAgent): """ Implementation of the following paper: Xiangde Luo, Minhao Hu, Wenjun Liao, Shuwei Zhai, Tao Song, Guotai Wang, Shaoting Zhang. ScribblScribble-Supervised Medical Image Segmentation via Dual-Branch Network and Dynamically Mixed Pseudo Labels Supervision. MICCAI 2022. + https://arxiv.org/abs/2203.02106 """ def __init__(self, config, stage = 'train'): net_type = config['network']['net_type'] - if net_type not in ['DualBranchUNet2D', 'DualBranchUNet3D']: + if net_type not in ['UNet2D_DualBranch', 'UNet3D_DualBranch']: raise ValueError("""For WSL_DMPLS, a dual branch network is expected. \ - It only supports DualBranchUNet2D and DualBranchUNet3D currently.""") - super(WSL_DMPLS, self).__init__(config, stage) + It only supports UNet2D_DualBranch and UNet3D_DualBranch currently.""") + super(WSLDMPLS, self).__init__(config, stage) def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] wsl_cfg = self.config['weakly_supervised_learning'] + iter_max = self.config['training']['iter_max'] + rampup_start = wsl_cfg.get('rampup_start', 0) + rampup_end = wsl_cfg.get('rampup_end', iter_max) train_loss = 0 train_loss_sup = 0 train_loss_reg = 0 @@ -75,18 +76,15 @@ def training(self): loss_dict2 = {"prediction":outputs2, 'ground_truth':pseudo_lab} loss_reg = 0.5 * (loss_calculator(loss_dict1) + loss_calculator(loss_dict2)) - iter_max = self.config['training']['iter_max'] - ramp_up_length = wsl_cfg.get('ramp_up_length', iter_max) - regular_w = 0.0 - if(self.glob_it > wsl_cfg.get('iter_sup', 0)): - regular_w = wsl_cfg.get('regularize_w', 0.1) - if(ramp_up_length is not None and self.glob_it < ramp_up_length): - regular_w = regular_w * sigmoid_rampup(self.glob_it, ramp_up_length) + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") + regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg loss.backward() self.optimizer.step() - self.scheduler.step() + if(self.scheduler is not None and \ + not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): + self.scheduler.step() train_loss = train_loss + loss.item() train_loss_sup = train_loss_sup + loss_sup.item() diff --git a/pymic/net_run_wsl/wsl_em.py b/pymic/net_run_wsl/wsl_em.py index 9534504..3b2d595 100644 --- a/pymic/net_run_wsl/wsl_em.py +++ b/pymic/net_run_wsl/wsl_em.py @@ -2,29 +2,30 @@ from __future__ import print_function, division import logging import numpy as np -import random import torch -import torchvision.transforms as transforms -from 
pymic.io.nifty_dataset import NiftyDataset +from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.ssl import EntropyLoss from pymic.net_run.agent_seg import SegmentationAgent -from pymic.transform.trans_dict import TransformDict -from pymic.util.ramps import sigmoid_rampup +from pymic.net_run_wsl.wsl_abstract import WSLSegAgent +from pymic.util.ramps import get_rampup_ratio -class WSL_EntropyMinimization(SegmentationAgent): +class WSLEntropyMinimization(WSLSegAgent): """ - Training and testing agent for semi-supervised segmentation + Weakly suepervised segmentation with Entropy Minimization Regularization. """ def __init__(self, config, stage = 'train'): - super(WSL_EntropyMinimization, self).__init__(config, stage) + super(WSLEntropyMinimization, self).__init__(config, stage) def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] wsl_cfg = self.config['weakly_supervised_learning'] + iter_max = self.config['training']['iter_max'] + rampup_start = wsl_cfg.get('rampup_start', 0) + rampup_end = wsl_cfg.get('rampup_end', iter_max) train_loss = 0 train_loss_sup = 0 train_loss_reg = 0 @@ -52,18 +53,15 @@ def training(self): loss_dict= {"prediction":outputs, 'softmax':True} loss_reg = EntropyLoss()(loss_dict) - iter_max = self.config['training']['iter_max'] - ramp_up_length = wsl_cfg.get('ramp_up_length', iter_max) - regular_w = 0.0 - if(self.glob_it > wsl_cfg.get('iter_sup', 0)): - regular_w = wsl_cfg.get('regularize_w', 0.1) - if(ramp_up_length is not None and self.glob_it < ramp_up_length): - regular_w = regular_w * sigmoid_rampup(self.glob_it, ramp_up_length) + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") + regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg loss.backward() self.optimizer.step() - self.scheduler.step() + if(self.scheduler is not None and \ + not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): + self.scheduler.step() train_loss = train_loss + loss.item() train_loss_sup = train_loss_sup + loss_sup.item() @@ -85,27 +83,4 @@ def training(self): train_scalers = {'loss': train_avg_loss, 'loss_sup':train_avg_loss_sup, 'loss_reg':train_avg_loss_reg, 'regular_w':regular_w, 'avg_dice':train_avg_dice, 'class_dice': train_cls_dice} - return train_scalers - - def write_scalars(self, train_scalars, valid_scalars, glob_it): - loss_scalar ={'train':train_scalars['loss'], - 'valid':valid_scalars['loss']} - loss_sup_scalar = {'train':train_scalars['loss_sup']} - loss_upsup_scalar = {'train':train_scalars['loss_reg']} - dice_scalar ={'train':train_scalars['avg_dice'], 'valid':valid_scalars['avg_dice']} - self.summ_writer.add_scalars('loss', loss_scalar, glob_it) - self.summ_writer.add_scalars('loss_sup', loss_sup_scalar, glob_it) - self.summ_writer.add_scalars('loss_reg', loss_upsup_scalar, glob_it) - self.summ_writer.add_scalars('regular_w', {'regular_w':train_scalars['regular_w']}, glob_it) - self.summ_writer.add_scalars('dice', dice_scalar, glob_it) - class_num = self.config['network']['class_num'] - for c in range(class_num): - cls_dice_scalar = {'train':train_scalars['class_dice'][c], \ - 'valid':valid_scalars['class_dice'][c]} - self.summ_writer.add_scalars('class_{0:}_dice'.format(c), cls_dice_scalar, glob_it) - logging.info('train loss {0:.4f}, avg dice {1:.4f} 
'.format( - train_scalars['loss'], train_scalars['avg_dice']) + "[" + \ - ' '.join("{0:.4f}".format(x) for x in train_scalars['class_dice']) + "]") - logging.info('valid loss {0:.4f}, avg dice {1:.4f} '.format( - valid_scalars['loss'], valid_scalars['avg_dice']) + "[" + \ - ' '.join("{0:.4f}".format(x) for x in valid_scalars['class_dice']) + "]") + return train_scalers \ No newline at end of file diff --git a/pymic/net_run_wsl/wsl_gatedcrf.py b/pymic/net_run_wsl/wsl_gatedcrf.py index 8e9c6de..64e0f1b 100644 --- a/pymic/net_run_wsl/wsl_gatedcrf.py +++ b/pymic/net_run_wsl/wsl_gatedcrf.py @@ -2,24 +2,26 @@ from __future__ import print_function, division import logging import numpy as np -import random import torch -import torchvision.transforms as transforms -from pymic.io.nifty_dataset import NiftyDataset +from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.gatedcrf import ModelLossSemsegGatedCRF -from pymic.net_run.agent_seg import SegmentationAgent -from pymic.net_run_wsl.wsl_em import WSL_EntropyMinimization -from pymic.util.ramps import sigmoid_rampup +from pymic.net_run_wsl.wsl_abstract import WSLSegAgent +from pymic.util.ramps import get_rampup_ratio -class WSL_GatedCRF(WSL_EntropyMinimization): +class WSLGatedCRF(WSLSegAgent): """ - Training and testing agent for semi-supervised segmentation + Implementation of the Gated CRF Loss for Weakly Supervised Semantic Image Segmentation. + Anton Obukhov, Stamatios Georgoulis, Dengxin Dai, Luc Van Gool: + Gated CRF Loss for Weakly Supervised Semantic Image Segmentation. + CoRR, abs/1906.04651, 2019 + http://arxiv.org/abs/1906.04651 + } """ def __init__(self, config, stage = 'train'): - super(WSL_GatedCRF, self).__init__(config, stage) + super(WSLGatedCRF, self).__init__(config, stage) # parameters for gated CRF wsl_cfg = self.config['weakly_supervised_learning'] w0 = wsl_cfg.get('GatedCRFLoss_W0'.lower(), 1.0) @@ -36,6 +38,9 @@ def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] wsl_cfg = self.config['weakly_supervised_learning'] + iter_max = self.config['training']['iter_max'] + rampup_start = wsl_cfg.get('rampup_start', 0) + rampup_end = wsl_cfg.get('rampup_end', iter_max) train_loss = 0 train_loss_sup = 0 train_loss_reg = 0 @@ -79,18 +84,15 @@ def training(self): loss_reg = gatecrf_loss(outputs_soft, self.kernels, self.radius, batch_dict,input_shape[-2], input_shape[-1])["loss"] - iter_max = self.config['training']['iter_max'] - ramp_up_length = wsl_cfg.get('ramp_up_length', iter_max) - regular_w = 0.0 - if(self.glob_it > wsl_cfg.get('iter_sup', 0)): - regular_w = wsl_cfg.get('regularize_w', 0.1) - if(ramp_up_length is not None and self.glob_it < ramp_up_length): - regular_w = regular_w * sigmoid_rampup(self.glob_it, ramp_up_length) + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") + regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg loss.backward() self.optimizer.step() - self.scheduler.step() + if(self.scheduler is not None and \ + not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): + self.scheduler.step() train_loss = train_loss + loss.item() train_loss_sup = train_loss_sup + loss_sup.item() diff --git a/pymic/net_run_wsl/wsl_main.py b/pymic/net_run_wsl/wsl_main.py index 595aa3e..abedb6b 100644 --- 
a/pymic/net_run_wsl/wsl_main.py +++ b/pymic/net_run_wsl/wsl_main.py @@ -5,19 +5,19 @@ import os import sys from pymic.util.parse_config import * -from pymic.net_run_wsl.wsl_em import WSL_EntropyMinimization -from pymic.net_run_wsl.wsl_gatedcrf import WSL_GatedCRF -from pymic.net_run_wsl.wsl_mumford_shah import WSL_MumfordShah -from pymic.net_run_wsl.wsl_tv import WSL_TotalVariation -from pymic.net_run_wsl.wsl_ustm import WSL_USTM -from pymic.net_run_wsl.wsl_dmpls import WSL_DMPLS +from pymic.net_run_wsl.wsl_em import WSLEntropyMinimization +from pymic.net_run_wsl.wsl_gatedcrf import WSLGatedCRF +from pymic.net_run_wsl.wsl_mumford_shah import WSLMumfordShah +from pymic.net_run_wsl.wsl_tv import WSLTotalVariation +from pymic.net_run_wsl.wsl_ustm import WSLUSTM +from pymic.net_run_wsl.wsl_dmpls import WSLDMPLS -WSLMethodDict = {'EntropyMinimization': WSL_EntropyMinimization, - 'GatedCRF': WSL_GatedCRF, - 'MumfordShah': WSL_MumfordShah, - 'TotalVariation': WSL_TotalVariation, - 'USTM': WSL_USTM, - 'DMPLS': WSL_DMPLS} +WSLMethodDict = {'EntropyMinimization': WSLEntropyMinimization, + 'GatedCRF': WSLGatedCRF, + 'MumfordShah': WSLMumfordShah, + 'TotalVariation': WSLTotalVariation, + 'USTM': WSLUSTM, + 'DMPLS': WSLDMPLS} def main(): if(len(sys.argv) < 3): @@ -31,7 +31,7 @@ def main(): log_dir = config['training']['ckpt_save_dir'] if(not os.path.exists(log_dir)): os.mkdir(log_dir) - logging.basicConfig(filename=log_dir+"/log.txt", level=logging.INFO, + logging.basicConfig(filename=log_dir+"/log_{0:}.txt".format(stage), level=logging.INFO, format='%(message)s') logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) logging_config(config) diff --git a/pymic/net_run_wsl/wsl_mumford_shah.py b/pymic/net_run_wsl/wsl_mumford_shah.py index 909a65b..df4c68f 100644 --- a/pymic/net_run_wsl/wsl_mumford_shah.py +++ b/pymic/net_run_wsl/wsl_mumford_shah.py @@ -2,29 +2,32 @@ from __future__ import print_function, division import logging import numpy as np -import random import torch -import torchvision.transforms as transforms -from pymic.io.nifty_dataset import NiftyDataset +from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.mumford_shah import MumfordShahLoss -from pymic.net_run.agent_seg import SegmentationAgent -from pymic.net_run_wsl.wsl_em import WSL_EntropyMinimization -from pymic.util.ramps import sigmoid_rampup +from pymic.net_run_wsl.wsl_abstract import WSLSegAgent +from pymic.util.ramps import get_rampup_ratio -class WSL_MumfordShah(WSL_EntropyMinimization): +class WSLMumfordShah(WSLSegAgent): """ - Training and testing agent for semi-supervised segmentation + Weakly supervised learning with Mumford Shah Loss according to this paper: + Boah Kim and Jong Chul Ye: Mumford–Shah Loss Functional + for Image Segmentation With Deep Learning. IEEE TIP, 2019. 
+ https://doi.org/10.1109/TIP.2019.2941265 """ def __init__(self, config, stage = 'train'): - super(WSL_MumfordShah, self).__init__(config, stage) + super(WSLMumfordShah, self).__init__(config, stage) def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] wsl_cfg = self.config['weakly_supervised_learning'] + iter_max = self.config['training']['iter_max'] + rampup_start = wsl_cfg.get('rampup_start', 0) + rampup_end = wsl_cfg.get('rampup_end', iter_max) train_loss = 0 train_loss_sup = 0 train_loss_reg = 0 @@ -54,18 +57,15 @@ def training(self): loss_dict = {"prediction":outputs, 'image':inputs} loss_reg = reg_loss_calculator(loss_dict) - iter_max = self.config['training']['iter_max'] - ramp_up_length = wsl_cfg.get('ramp_up_length', iter_max) - regular_w = 0.0 - if(self.glob_it > wsl_cfg.get('iter_sup', 0)): - regular_w = wsl_cfg.get('regularize_w', 0.1) - if(ramp_up_length is not None and self.glob_it < ramp_up_length): - regular_w = regular_w * sigmoid_rampup(self.glob_it, ramp_up_length) + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") + regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg # if (self.config['training']['use']) loss.backward() self.optimizer.step() - self.scheduler.step() + if(self.scheduler is not None and \ + not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): + self.scheduler.step() train_loss = train_loss + loss.item() train_loss_sup = train_loss_sup + loss_sup.item() diff --git a/pymic/net_run_wsl/wsl_tv.py b/pymic/net_run_wsl/wsl_tv.py index f11c5e0..2e56cb4 100644 --- a/pymic/net_run_wsl/wsl_tv.py +++ b/pymic/net_run_wsl/wsl_tv.py @@ -2,29 +2,30 @@ from __future__ import print_function, division import logging import numpy as np -import random import torch -import torchvision.transforms as transforms -from pymic.io.nifty_dataset import NiftyDataset +from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice from pymic.loss.seg.ssl import TotalVariationLoss -from pymic.net_run.agent_seg import SegmentationAgent -from pymic.net_run_wsl.wsl_em import WSL_EntropyMinimization -from pymic.util.ramps import sigmoid_rampup +from pymic.net_run_wsl.wsl_abstract import WSLSegAgent +from pymic.util.ramps import get_rampup_ratio +from pymic.util.general import keyword_match -class WSL_TotalVariation(WSL_EntropyMinimization): +class WSLTotalVariation(WSLSegAgent): """ - Training and testing agent for semi-supervised segmentation + Weakly supervised segmentation with Total Variation Regularization. 
""" def __init__(self, config, stage = 'train'): - super(WSL_TotalVariation, self).__init__(config, stage) + super(WSLTotalVariation, self).__init__(config, stage) def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] wsl_cfg = self.config['weakly_supervised_learning'] + iter_max = self.config['training']['iter_max'] + rampup_start = wsl_cfg.get('rampup_start', 0) + rampup_end = wsl_cfg.get('rampup_end', iter_max) train_loss = 0 train_loss_sup = 0 train_loss_reg = 0 @@ -52,18 +53,15 @@ def training(self): loss_dict = {"prediction":outputs, 'softmax':True} loss_reg = TotalVariationLoss()(loss_dict) - iter_max = self.config['training']['iter_max'] - ramp_up_length = wsl_cfg.get('ramp_up_length', iter_max) - regular_w = 0.0 - if(self.glob_it > wsl_cfg.get('iter_sup', 0)): - regular_w = wsl_cfg.get('regularize_w', 0.1) - if(ramp_up_length is not None and self.glob_it < ramp_up_length): - regular_w = regular_w * sigmoid_rampup(self.glob_it, ramp_up_length) + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") + regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg # if (self.config['training']['use']) loss.backward() self.optimizer.step() - self.scheduler.step() + if(self.scheduler is not None and \ + not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): + self.scheduler.step() train_loss = train_loss + loss.item() train_loss_sup = train_loss_sup + loss_sup.item() diff --git a/pymic/net_run_wsl/wsl_ustm.py b/pymic/net_run_wsl/wsl_ustm.py index c7306e8..0a2f7e1 100644 --- a/pymic/net_run_wsl/wsl_ustm.py +++ b/pymic/net_run_wsl/wsl_ustm.py @@ -5,27 +5,30 @@ import random import torch import torch.nn.functional as F -import torchvision.transforms as transforms -from pymic.io.nifty_dataset import NiftyDataset +from torch.optim import lr_scheduler from pymic.loss.seg.util import get_soft_label from pymic.loss.seg.util import reshape_prediction_and_ground_truth from pymic.loss.seg.util import get_classwise_dice -from pymic.loss.seg.ssl import EntropyLoss from pymic.net.net_dict_seg import SegNetDict -from pymic.net_run_wsl.wsl_em import WSL_EntropyMinimization -from pymic.transform.trans_dict import TransformDict -from pymic.util.ramps import sigmoid_rampup +from pymic.net_run_wsl.wsl_abstract import WSLSegAgent +from pymic.util.ramps import get_rampup_ratio +from pymic.util.general import keyword_match -class WSL_USTM(WSL_EntropyMinimization): +class WSLUSTM(WSLSegAgent): """ - Training and testing agent for semi-supervised segmentation + USTM for scribble-supervised segmentation according to the following paper: + Xiaoming Liu, Quan Yuan, Yaozong Gao, Helei He, Shuo Wang, Xiao Tang, + Jinshan Tang, Dinggang Shen: + Weakly Supervised Segmentation of COVID19 Infection with Scribble Annotation on CT Images. + Patter Recognition, 2022. 
+ https://doi.org/10.1016/j.patcog.2021.108341 """ def __init__(self, config, stage = 'train'): - super(WSL_USTM, self).__init__(config, stage) + super(WSLUSTM, self).__init__(config, stage) self.net_ema = None def create_network(self): - super(WSL_USTM, self).create_network() + super(WSLUSTM, self).create_network() if(self.net_ema is None): net_name = self.config['network']['net_type'] if(net_name not in SegNetDict): @@ -40,6 +43,9 @@ def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] wsl_cfg = self.config['weakly_supervised_learning'] + iter_max = self.config['training']['iter_max'] + rampup_start = wsl_cfg.get('rampup_start', 0) + rampup_end = wsl_cfg.get('rampup_end', iter_max) train_loss = 0 train_loss_sup = 0 train_loss_reg = 0 @@ -95,24 +101,20 @@ def training(self): uncertainty = -1.0 * torch.sum(preds*torch.log(preds + 1e-6), dim=1, keepdim=True) - iter_max = self.config['training']['iter_max'] - ramp_up_length = wsl_cfg.get('ramp_up_length', iter_max) - threshold_ramp = sigmoid_rampup(self.glob_it, iter_max) + rampup_ratio = get_rampup_ratio(self.glob_it, rampup_start, rampup_end, "sigmoid") class_num = list(y.shape)[1] - threshold = (0.75+0.25*threshold_ramp)*np.log(class_num) + threshold = (0.75+0.25*rampup_ratio)*np.log(class_num) mask = (uncertainty < threshold).float() loss_reg = torch.sum(mask*square_error)/(2*torch.sum(mask)+1e-16) - regular_w = 0.0 - if(self.glob_it > wsl_cfg.get('iter_sup', 0)): - regular_w = wsl_cfg.get('regularize_w', 0.1) - if(ramp_up_length is not None and self.glob_it < ramp_up_length): - regular_w = regular_w * sigmoid_rampup(self.glob_it, ramp_up_length) + regular_w = wsl_cfg.get('regularize_w', 0.1) * rampup_ratio loss = loss_sup + regular_w*loss_reg loss.backward() self.optimizer.step() - self.scheduler.step() + if(self.scheduler is not None and \ + not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): + self.scheduler.step() # update EMA alpha = wsl_cfg.get('ema_decay', 0.99) diff --git a/pymic/transform/gamma_correction.py b/pymic/transform/gamma_correction.py deleted file mode 100644 index 4a88f1c..0000000 --- a/pymic/transform/gamma_correction.py +++ /dev/null @@ -1,40 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division - -import torch -import json -import math -import random -import numpy as np -from scipy import ndimage -from pymic.transform.abstract_transform import AbstractTransform -from pymic.util.image_process import * - - -class ChannelWiseGammaCorrection(AbstractTransform): - """ - apply random gamma correction to each channel - """ - def __init__(self, params): - """ - (gamma_min, gamma_max) specify the range of gamma - """ - super(ChannelWiseGammaCorrection, self).__init__(params) - self.gamma_min = params['ChannelWiseGammaCorrection_gamma_min'.lower()] - self.gamma_max = params['ChannelWiseGammaCorrection_gamma_max'.lower()] - self.inverse = params.get('ChannelWiseGammaCorrection_inverse'.lower(), False) - - def __call__(self, sample): - image= sample['image'] - for chn in range(image.shape[0]): - gamma_c = random.random() * (self.gamma_max - self.gamma_min) + self.gamma_min - img_c = image[chn] - v_min = img_c.min() - v_max = img_c.max() - img_c = (img_c - v_min)/(v_max - v_min) - img_c = np.power(img_c, gamma_c)*(v_max - v_min) + v_min - image[chn] = img_c - - sample['image'] = image - return sample - diff --git a/pymic/transform/intensity.py b/pymic/transform/intensity.py new file mode 100644 index 0000000..b9e6070 --- 
/dev/null +++ b/pymic/transform/intensity.py @@ -0,0 +1,70 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division + +import torch +import json +import math +import random +import numpy as np +from scipy import ndimage +from pymic.transform.abstract_transform import AbstractTransform +from pymic.util.image_process import * + + +class GammaCorrection(AbstractTransform): + """ + apply random gamma correction to the given channels with a probability + """ + def __init__(self, params): + """ + (gamma_min, gamma_max) specify the range of gamma + """ + super(GammaCorrection, self).__init__(params) + self.channels = params['GammaCorrection_channels'.lower()] + self.gamma_min = params['GammaCorrection_gamma_min'.lower()] + self.gamma_max = params['GammaCorrection_gamma_max'.lower()] + self.prob = params.get('GammaCorrection_probability'.lower(), 0.5) + self.inverse = params.get('GammaCorrection_inverse'.lower(), False) + + def __call__(self, sample): + if(np.random.uniform() > self.prob): + return sample + image= sample['image'] + for chn in self.channels: + gamma_c = random.random() * (self.gamma_max - self.gamma_min) + self.gamma_min + img_c = image[chn] + v_min = img_c.min() + v_max = img_c.max() + img_c = (img_c - v_min)/(v_max - v_min) + img_c = np.power(img_c, gamma_c)*(v_max - v_min) + v_min + image[chn] = img_c + + sample['image'] = image + return sample + +class GaussianNoise(AbstractTransform): + """ + add random Gaussian noise to the given channels with a probability + """ + def __init__(self, params): + """ + (mean, std) specify the mean and standard deviation of the Gaussian noise + """ + super(GaussianNoise, self).__init__(params) + self.channels = params['GaussianNoise_channels'.lower()] + self.mean = params['GaussianNoise_mean'.lower()] + self.std = params['GaussianNoise_std'.lower()] + self.prob = params.get('GaussianNoise_probability'.lower(), 0.5) + self.inverse = params.get('GaussianNoise_inverse'.lower(), False) + + def __call__(self, sample): + if(np.random.uniform() > self.prob): + return sample + image= sample['image'] + for chn in self.channels: + img_c = image[chn] + noise = np.random.normal(self.mean, self.std, img_c.shape) + image[chn] = img_c + noise + + sample['image'] = image + return sample diff --git a/pymic/transform/trans_dict.py b/pymic/transform/trans_dict.py index ae9ce9c..d90e431 100644 --- a/pymic/transform/trans_dict.py +++ b/pymic/transform/trans_dict.py @@ -1,8 +1,9 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division -from pymic.transform.gamma_correction import ChannelWiseGammaCorrection +from pymic.transform.intensity import * from pymic.transform.gray2rgb import GrayscaleToRGB from pymic.transform.flip import RandomFlip +from pymic.transform.intensity import GaussianNoise from pymic.transform.pad import Pad from pymic.transform.rotate import RandomRotate from pymic.transform.rescale import Rescale, RandomRescale @@ -12,12 +13,13 @@ from pymic.transform.label_convert import * TransformDict = { - 'ChannelWiseGammaCorrection': ChannelWiseGammaCorrection, 'ChannelWiseThreshold': ChannelWiseThreshold, 'ChannelWiseThresholdWithNormalize': ChannelWiseThresholdWithNormalize, 'CropWithBoundingBox': CropWithBoundingBox, 'CenterCrop': CenterCrop, 'GrayscaleToRGB': GrayscaleToRGB, + 'GammaCorrection': GammaCorrection, + 'GaussianNoise': GaussianNoise, 'LabelConvert': LabelConvert, 'LabelConvertNonzero': LabelConvertNonzero, 'LabelToProbability': LabelToProbability, diff --git a/pymic/util/average_model.py b/pymic/util/average_model.py index 0b6fb29..73a537f 100644 --- 
a/pymic/util/average_model.py +++ b/pymic/util/average_model.py @@ -1,3 +1,4 @@ + import torch checkpoint_name1 = "/home/guotai/projects/PyMIC/examples/brats/model/casecade/wt/unet3d_4_8000.pt" diff --git a/pymic/util/evaluation_seg.py b/pymic/util/evaluation_seg.py index 61ae51c..b04880a 100644 --- a/pymic/util/evaluation_seg.py +++ b/pymic/util/evaluation_seg.py @@ -74,7 +74,7 @@ def get_edge_points(img): return edge -def binary_hausdorff95(s, g, spacing = None): +def binary_hd95(s, g, spacing = None): """ get the hausdorff distance between a binary segmentation and the ground truth inputs: @@ -165,8 +165,8 @@ def get_binary_evaluation_score(s_volume, g_volume, spacing, metric): elif(metric_lower == 'assd'): score = binary_assd(s_volume, g_volume, spacing) - elif(metric_lower == "hausdorff95"): - score = binary_hausdorff95(s_volume, g_volume, spacing) + elif(metric_lower == "hd95"): + score = binary_hd95(s_volume, g_volume, spacing) elif(metric_lower == "rve"): score = binary_relative_volume_error(s_volume, g_volume) diff --git a/pymic/util/general.py b/pymic/util/general.py new file mode 100644 index 0000000..063d654 --- /dev/null +++ b/pymic/util/general.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- +from __future__ import print_function, division +import torch +import numpy as np + +def keyword_match(a,b): + return a.lower() == b.lower() + +def get_one_hot_seg(label, class_num): + """ + convert a segmentation label to one-hot + label: a tensor with a shape of [N, 1, D, H, W] or [N, 1, H, W] + class_num: class number. + output: an one-hot tensor with a shape of [N, C, D, H, W] or [N, C, H, W] + """ + size = list(label.size()) + if(size[1] != 1): + raise ValueError("The channel should be 1, \ + rather than {0:} before one-hot encoding".format(size[1])) + label = label.view(-1) + ones = torch.sparse.torch.eye(class_num).to(label.device) + one_hot = ones.index_select(0, label) + size.append(class_num) + one_hot = one_hot.view(*size) + one_hot = torch.transpose(one_hot, 1, -1) + one_hot = torch.squeeze(one_hot, -1) + return one_hot \ No newline at end of file diff --git a/pymic/util/image_process.py b/pymic/util/image_process.py index 61b577b..896e8c1 100644 --- a/pymic/util/image_process.py +++ b/pymic/util/image_process.py @@ -107,25 +107,26 @@ def crop_and_pad_ND_array_to_desired_shape(image, out_shape, pad_mod): return image_pad -def get_largest_component(image): +def get_largest_k_components(image, k = 1): """ - get the largest component from 2D or 3D binary image + get the largest K components from 2D or 3D binary image image: nd array """ dim = len(image.shape) if(image.sum() == 0 ): print('the largest component is null') return image - if(dim == 2): - s = ndimage.generate_binary_structure(2,1) - elif(dim == 3): - s = ndimage.generate_binary_structure(3,1) - else: + if(dim < 2 or dim > 3): raise ValueError("the dimension number should be 2 or 3") + s = ndimage.generate_binary_structure(dim,1) labeled_array, numpatches = ndimage.label(image, s) sizes = ndimage.sum(image, labeled_array, range(1, numpatches + 1)) - max_label = np.where(sizes == sizes.max())[0] + 1 - output = np.asarray(labeled_array == max_label, np.uint8) + sizes_sort = sorted(sizes, reverse = True) + kmin = min(k, numpatches) + output = np.zeros_like(image) + for i in range(kmin): + labeli = np.where(sizes == sizes_sort[i])[0] + 1 + output = output + np.asarray(labeled_array == labeli, np.uint8) return output def get_euclidean_distance(image, dim = 3, spacing = [1.0, 1.0, 1.0]): diff --git a/pymic/util/post_process.py 
b/pymic/util/post_process.py new file mode 100644 index 0000000..da133ca --- /dev/null +++ b/pymic/util/post_process.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +from __future__ import absolute_import, print_function + +import os +import numpy as np +import SimpleITK as sitk +from pymic.util.image_process import get_largest_k_components + +class PostProcess(object): + def __init__(self, params): + self.params = params + + def __call__(self, seg): + return seg + +class PostKeepLargestComponent(PostProcess): + def __init__(self, params): + super(PostKeepLargestComponent, self).__init__(params) + self.mode = params.get("KeepLargestComponent_mode".lower(), 1) + """ + mode = 1: keep the largest component of the union of foreground classes. + mode = 2: keep the largest component for each foreground class. + """ + + def __call__(self, seg): + if(self.mode == 1): + mask = np.asarray(seg > 0, np.uint8) + mask = get_largest_k_components(mask) + seg = seg * mask + elif(self.mode == 2): + class_num = seg.max() + output = np.zeros_like(seg) + for c in range(1, class_num + 1): + seg_c = np.asarray(seg == c, np.uint8) + seg_c = get_largest_k_components(seg_c) + output = output + seg_c * c + seg = output + return seg + +PostProcessDict = { + 'KeepLargestComponent': PostKeepLargestComponent} \ No newline at end of file diff --git a/pymic/util/ramps.py b/pymic/util/ramps.py index e344cfe..b58adb6 100644 --- a/pymic/util/ramps.py +++ b/pymic/util/ramps.py @@ -10,24 +10,21 @@ 0 and 1. """ -def sigmoid_rampup(i, length): - """Exponential rampup from https://arxiv.org/abs/1610.02242""" - if length == 0: - return 1.0 - else: - i = np.clip(i, 0.0, length) - phase = 1.0 - (i + 0.0) / length - return float(np.exp(-5.0 * phase * phase)) - -def linear_rampup(i, length): - """Linear rampup""" - assert i >= 0 and length >= 0 - i = np.clip(i, 0.0, length) - return (i + 0.0) / length +def get_rampup_ratio(i, start, end, mode = "linear"): + if( i < start): + rampup = 0.0 + elif(i > end): + rampup = 1.0 + elif(mode == "linear"): + rampup = (i - start) / (end - start) + elif(mode == "sigmoid"): + phase = 1.0 - (i - start) / (end - start) + rampup = float(np.exp(-5.0 * phase * phase)) + return rampup -def cosine_rampdown(i, length): +def cosine_rampdown(i, start, end): """Cosine rampdown from https://arxiv.org/abs/1608.03983""" - i = np.clip(i, 0.0, length) - return float(.5 * (np.cos(np.pi * i / length) + 1)) + i = np.clip(i, start, end) + return float(.5 * (np.cos(np.pi * (i - start) / (end - start)) + 1)) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..2dc1604 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,12 @@ +matplotlib>=3.1.2 +numpy>=1.17.4 +pandas>=0.25.3 +python>=3.6 +scikit-image>=0.16.2 +scikit-learn>=0.22 +scipy>=1.3.3 +SimpleITK>=1.2.4 +tensorboard>=2.1.0 +tensorboardX>=1.9 +torch>=1.7.1 +torchvision>=0.8.2 diff --git a/setup.py b/setup.py index 498aa26..ce7271b 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ # Get the summary description = 'An open-source deep learning platform' + \ - ' for medical image computing' + ' for annotation-efficient medical image computing' # Get the long description with open('README.md', encoding='utf-8') as f: @@ -11,7 +11,7 @@ setuptools.setup( name = 'PYMIC', - version = "0.2.5", + version = "0.3.0", author ='PyMIC Consortium', author_email = 'wguotai@gmail.com', description = description, @@ -31,6 +31,8 @@ 'console_scripts': [ 'pymic_run = pymic.net_run.net_run:main', 'pymic_ssl = pymic.net_run_ssl.ssl_main:main', + 'pymic_wsl = pymic.net_run_wsl.wsl_main:main', + 'pymic_nll = pymic.net_run_nll.nll_main:main', 
'pymic_eval_cls = pymic.util.evaluation_cls:main', 'pymic_eval_seg = pymic.util.evaluation_seg:main' ],
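
A note on the ramp-up refactoring applied across the SSL and WSL agents in this changeset: each agent now computes its regularization weight as `regularize_w * get_rampup_ratio(glob_it, rampup_start, rampup_end, "sigmoid")`, where `rampup_start` (default 0) and `rampup_end` (default `iter_max`) come from the `semi_supervised_learning` or `weakly_supervised_learning` section of the configuration. The snippet below is a minimal sketch of how the weight grows from 0 to `regularize_w`; the iteration numbers and weight value are illustrative only, not PyMIC defaults.

```python
# Illustrative sketch only: inspecting the new ramp-up schedule used by the
# SSL/WSL training agents. The config values below are made-up examples.
from pymic.util.ramps import get_rampup_ratio

wsl_cfg = {'regularize_w': 0.1, 'rampup_start': 1000, 'rampup_end': 15000}
for glob_it in [0, 1000, 5000, 10000, 15000, 20000]:
    ratio = get_rampup_ratio(glob_it, wsl_cfg['rampup_start'],
                             wsl_cfg['rampup_end'], "sigmoid")
    regular_w = wsl_cfg['regularize_w'] * ratio
    print("iteration {0:>6d}: regular_w = {1:.4f}".format(glob_it, regular_w))
```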
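The new intensity transforms (`GammaCorrection`, `GaussianNoise`) registered in `TransformDict` read their settings from lower-cased keys in the parameter dictionary, following the same convention as the other transforms. The sketch below calls `GaussianNoise` directly rather than through a configuration file; the channel list, mean, standard deviation and probability are example values, not recommended settings.

```python
# Illustrative sketch only: applying the new GaussianNoise transform by hand.
import numpy as np
from pymic.transform.intensity import GaussianNoise

params = {'gaussiannoise_channels': [0],      # apply noise to channel 0 only
          'gaussiannoise_mean': 0.0,
          'gaussiannoise_std': 0.05,
          'gaussiannoise_probability': 1.0}   # always apply, for demonstration
transform = GaussianNoise(params)

# a sample dict with a [C, D, H, W] image, as used by PyMIC's dataset classes
sample = {'image': np.random.rand(1, 32, 48, 48).astype(np.float32)}
sample = transform(sample)  # Gaussian noise added to the selected channel
```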