From 55d2edc3f982a9be0e0fa739542faec28f8180f5 Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 1 Oct 2022 10:23:44 +0800 Subject: [PATCH 1/8] update classification agent --- pymic/loss/seg/ce.py | 1 - pymic/net_run/agent_cls.py | 8 +++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/pymic/loss/seg/ce.py b/pymic/loss/seg/ce.py index bbe6d02..529482b 100644 --- a/pymic/loss/seg/ce.py +++ b/pymic/loss/seg/ce.py @@ -18,7 +18,6 @@ class CrossEntropyLoss(AbstractSegLoss): """ def __init__(self, params = None): super(CrossEntropyLoss, self).__init__(params) - def forward(self, loss_input_dict): predict = loss_input_dict['prediction'] diff --git a/pymic/net_run/agent_cls.py b/pymic/net_run/agent_cls.py index 8ab7729..46bb2d7 100644 --- a/pymic/net_run/agent_cls.py +++ b/pymic/net_run/agent_cls.py @@ -145,7 +145,8 @@ def training(self): self.optimizer.zero_grad() # forward + backward + optimize outputs = self.net(inputs) - loss = self.get_loss_value(data, inputs, outputs, labels) + + loss = self.get_loss_value(data, outputs, labels) loss.backward() self.optimizer.step() self.scheduler.step() @@ -175,7 +176,7 @@ def validation(self): self.optimizer.zero_grad() # forward + backward + optimize outputs = self.net(inputs) - loss = self.get_loss_value(data, inputs, outputs, labels) + loss = self.get_loss_value(data, outputs, labels) # statistics sample_num += labels.size(0) @@ -243,10 +244,11 @@ def train_valid(self): logging.info("{0:} training start".format(str(datetime.now())[:-7])) self.summ_writer = SummaryWriter(self.config['training']['ckpt_save_dir']) for it in range(iter_start, iter_max, iter_valid): + lr_value = self.optimizer.param_groups[0]['lr'] train_scalars = self.training() valid_scalars = self.validation() glob_it = it + iter_valid - self.write_scalars(train_scalars, valid_scalars, glob_it) + self.write_scalars(train_scalars, valid_scalars, lr_value, glob_it) if(valid_scalars[metrics] > self.max_val_score): self.max_val_score = valid_scalars[metrics] From ac1c8fcb63811780c4a7597c36e599f5db51470e Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 1 Oct 2022 10:59:26 +0800 Subject: [PATCH 2/8] Update torch_pretrained_net.py --- pymic/net/cls/torch_pretrained_net.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pymic/net/cls/torch_pretrained_net.py b/pymic/net/cls/torch_pretrained_net.py index edc5d9c..5017f72 100644 --- a/pymic/net/cls/torch_pretrained_net.py +++ b/pymic/net/cls/torch_pretrained_net.py @@ -75,7 +75,7 @@ def __init__(self, params): def get_parameters_to_update(self): if(self.update_mode == "all"): return self.net.parameters() - elif(self.update_layers == "last"): + elif(self.update_mode == "last"): params = self.net.fc.parameters() if(self.in_chns !=3): # combining the two iterables into a single one @@ -119,7 +119,7 @@ def get_parameters_to_update(self): params = self.net.classifier[-1].parameters() if(self.in_chns !=3): params = itertools.chain() - for pram in [self.net.classifier[-1].parameters(), self.net.net.features[0].parameters()]: + for pram in [self.net.classifier[-1].parameters(), self.net.features[0].parameters()]: params = itertools.chain(params, pram) return params else: @@ -138,7 +138,7 @@ class MobileNetV2(BuiltInNet): as well as the first layer when `input_chns` is not 3. """ def __init__(self, params): - super(MobileNetV2, self).__init__() + super(MobileNetV2, self).__init__(params) self.net = models.mobilenet_v2(pretrained = self.pretrain) # replace the last layer @@ -157,7 +157,7 @@ def get_parameters_to_update(self): params = self.net.classifier[-1].parameters() if(self.in_chns !=3): params = itertools.chain() - for pram in [self.net.classifier[-1].parameters(), self.net.net.features[0][0].parameters()]: + for pram in [self.net.classifier[-1].parameters(), self.net.features[0][0].parameters()]: params = itertools.chain(params, pram) return params else: From c5415254375657ea35231bbe1a98caaafef2f336 Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 1 Oct 2022 14:06:34 +0800 Subject: [PATCH 3/8] enable ReduceLROnPlateau --- pymic/net_run/agent_cls.py | 43 +++++++++++++++++++++++++++----------- pymic/net_run/agent_seg.py | 2 +- 2 files changed, 32 insertions(+), 13 deletions(-) diff --git a/pymic/net_run/agent_cls.py b/pymic/net_run/agent_cls.py index 46bb2d7..cbb79d4 100644 --- a/pymic/net_run/agent_cls.py +++ b/pymic/net_run/agent_cls.py @@ -5,11 +5,12 @@ import csv import logging import time -import torch -from torchvision import transforms import numpy as np +import torch import torch.nn as nn from datetime import datetime +from torch.optim import lr_scheduler +from torchvision import transforms from tensorboardX import SummaryWriter from pymic.io.nifty_dataset import ClassificationDataset from pymic.loss.loss_dict_cls import PyMICClsLossDict @@ -149,7 +150,9 @@ def training(self): loss = self.get_loss_value(data, outputs, labels) loss.backward() self.optimizer.step() - self.scheduler.step() + if(self.scheduler is not None and \ + not isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): + self.scheduler.step() # statistics sample_num += labels.size(0) @@ -185,7 +188,9 @@ def validation(self): avg_loss = running_loss / sample_num avg_score= running_score.double() / sample_num - metrics =self.config['training'].get("evaluation_metric", "accuracy") + metrics = self.config['training'].get("evaluation_metric", "accuracy") + if(isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau)): + self.scheduler.step(avg_score) valid_scalers = {'loss': avg_loss, metrics: avg_score} return valid_scalers @@ -222,7 +227,15 @@ def train_valid(self): iter_max = self.config['training']['iter_max'] iter_valid = self.config['training']['iter_valid'] iter_save = self.config['training']['iter_save'] + early_stop_it = self.config['training'].get('early_stop_patience', None) metrics = self.config['training'].get("evaluation_metric", "accuracy") + if(iter_save is None): + iter_save_list = [iter_max] + elif(isinstance(iter_save, (tuple, list))): + iter_save_list = iter_save + else: + iter_save_list = range(0, iter_max + 1, iter_save) + self.max_val_score = 0.0 self.max_val_it = 0 self.best_model_wts = None @@ -243,29 +256,35 @@ def train_valid(self): logging.info("{0:} training start".format(str(datetime.now())[:-7])) self.summ_writer = SummaryWriter(self.config['training']['ckpt_save_dir']) + self.glob_it = iter_start for it in range(iter_start, iter_max, iter_valid): lr_value = self.optimizer.param_groups[0]['lr'] train_scalars = self.training() valid_scalars = self.validation() - glob_it = it + iter_valid - self.write_scalars(train_scalars, valid_scalars, lr_value, glob_it) + self.glob_it = it + iter_valid + self.write_scalars(train_scalars, valid_scalars, lr_value, self.glob_it) if(valid_scalars[metrics] > self.max_val_score): self.max_val_score = valid_scalars[metrics] - self.max_val_it = glob_it + self.max_val_it = self.glob_it self.best_model_wts = copy.deepcopy(self.net.state_dict()) - if (glob_it % iter_save == 0): - save_dict = {'iteration': glob_it, + stop_now = True if(early_stop_it is not None and \ + self.glob_it - self.max_val_it > early_stop_it) else False + + if ((self.glob_it in iter_save_list) or stop_now): + save_dict = {'iteration': self.glob_it, 'valid_pred': valid_scalars[metrics], 'model_state_dict': self.net.state_dict(), 'optimizer_state_dict': self.optimizer.state_dict()} - save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, glob_it) + save_name = "{0:}/{1:}_{2:}.pt".format(ckpt_dir, ckpt_prefix, self.glob_it) torch.save(save_dict, save_name) txt_file = open("{0:}/{1:}_latest.txt".format(ckpt_dir, ckpt_prefix), 'wt') - txt_file.write(str(glob_it)) + txt_file.write(str(self.glob_it)) txt_file.close() - + if(stop_now): + logging.info("The training is early stopped") + break # save the best performing checkpoint save_dict = {'iteration': self.max_val_it, 'valid_pred': self.max_val_score, diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index bba4508..9220656 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -10,9 +10,9 @@ import numpy as np import torch.nn as nn import torch.optim as optim -from torch.optim import lr_scheduler import torch.nn.functional as F from datetime import datetime +from torch.optim import lr_scheduler from tensorboardX import SummaryWriter from pymic.io.image_read_write import save_nd_array_as_image from pymic.io.nifty_dataset import NiftyDataset From bd38bc983d000879196ae3070392e98a45909040 Mon Sep 17 00:00:00 2001 From: taigw Date: Sat, 1 Oct 2022 16:00:47 +0800 Subject: [PATCH 4/8] support mixup --- pymic/loss/cls/basic.py | 10 ++------ pymic/net_run/agent_cls.py | 21 ++++++++++------ pymic/net_run/agent_seg.py | 7 +++++- pymic/util/general.py | 50 +++++++++++++++++++++++++++++++++++++- 4 files changed, 71 insertions(+), 17 deletions(-) diff --git a/pymic/loss/cls/basic.py b/pymic/loss/cls/basic.py index 4c90943..56925fc 100644 --- a/pymic/loss/cls/basic.py +++ b/pymic/loss/cls/basic.py @@ -65,10 +65,7 @@ def forward(self, loss_input_dict): labels = loss_input_dict['ground_truth'][:, None] # reshape to N, 1 softmax = nn.Softmax(dim = 1) predict = softmax(predict) - num_class = list(predict.size())[1] - data_type = 'float' if(predict.dtype is torch.float32) else 'double' - soft_y = get_soft_label(labels, num_class, data_type) - loss = self.l1_loss(predict, soft_y) + loss = self.l1_loss(predict, labels) return loss class MSELoss(AbstractClassificationLoss): @@ -84,10 +81,7 @@ def forward(self, loss_input_dict): labels = loss_input_dict['ground_truth'][:, None] # reshape to N, 1 softmax = nn.Softmax(dim = 1) predict = softmax(predict) - num_class = list(predict.size())[1] - data_type = 'float' if(predict.dtype is torch.float32) else 'double' - soft_y = get_soft_label(labels, num_class, data_type) - loss = self.mse_loss(predict, soft_y) + loss = self.mse_loss(predict, labels) return loss class NLLLoss(AbstractClassificationLoss): diff --git a/pymic/net_run/agent_cls.py b/pymic/net_run/agent_cls.py index cbb79d4..f9f7781 100644 --- a/pymic/net_run/agent_cls.py +++ b/pymic/net_run/agent_cls.py @@ -9,6 +9,7 @@ import torch import torch.nn as nn from datetime import datetime +from random import random from torch.optim import lr_scheduler from torchvision import transforms from tensorboardX import SummaryWriter @@ -17,6 +18,7 @@ from pymic.net.net_dict_cls import TorchClsNetDict from pymic.transform.trans_dict import TransformDict from pymic.net_run.agent_abstract import NetRunAgent +from pymic.util.general import mixup import warnings warnings.filterwarnings('ignore', '.*output shape of zoom.*') @@ -111,16 +113,17 @@ def get_evaluation_score(self, outputs, labels): """ Get evaluation score for a prediction. - :param outputs: (tensor) Prediction obtained by a network. - :param labels: (tensor) The ground truth. + :param outputs: (tensor) Prediction obtained by a network with size N X C. + :param labels: (tensor) The ground truth with size N X C. """ metrics = self.config['training'].get("evaluation_metric", "accuracy") if(metrics != "accuracy"): # default classification accuracy raise ValueError("Not implemeted for metric {0:}".format(metrics)) if(self.task_type == "cls"): - _, preds = torch.max(outputs, 1) - consis= self.convert_tensor_type(preds == labels.data) - score = torch.mean(consis) + out_argmax = torch.argmax(outputs, 1) + lab_argmax = torch.argmax(labels, 1) + consis = self.convert_tensor_type(out_argmax == lab_argmax) + score = torch.mean(consis) elif(self.task_type == "cls_nexcl"): #nonexclusive classification preds = self.convert_tensor_type(outputs > 0.5) consis= self.convert_tensor_type(preds == labels.data) @@ -129,6 +132,7 @@ def get_evaluation_score(self, outputs, labels): def training(self): iter_valid = self.config['training']['iter_valid'] + mixup_prob = self.config['training'].get('mixup_probability', 0.5) sample_num = 0 running_loss = 0 running_score= 0 @@ -140,8 +144,11 @@ def training(self): self.trainIter = iter(self.train_loader) data = next(self.trainIter) inputs = self.convert_tensor_type(data['image']) - labels = data['label'].long() + labels = self.convert_tensor_type(data['label_prob']) + if(random() < mixup_prob): + inputs, labels = mixup(inputs, labels) inputs, labels = inputs.to(self.device), labels.to(self.device) + # zero the parameter gradients self.optimizer.zero_grad() # forward + backward + optimize @@ -174,7 +181,7 @@ def validation(self): self.net.eval() for data in validIter: inputs = self.convert_tensor_type(data['image']) - labels = data['label'].long() + labels = self.convert_tensor_type(data['label_prob']) inputs, labels = inputs.to(self.device), labels.to(self.device) self.optimizer.zero_grad() # forward + backward + optimize diff --git a/pymic/net_run/agent_seg.py b/pymic/net_run/agent_seg.py index 9220656..620f29b 100644 --- a/pymic/net_run/agent_seg.py +++ b/pymic/net_run/agent_seg.py @@ -12,6 +12,7 @@ import torch.optim as optim import torch.nn.functional as F from datetime import datetime +from random import random from torch.optim import lr_scheduler from tensorboardX import SummaryWriter from pymic.io.image_read_write import save_nd_array_as_image @@ -28,6 +29,7 @@ from pymic.transform.trans_dict import TransformDict from pymic.util.post_process import PostProcessDict from pymic.util.image_process import convert_label +from pymic.util.general import mixup class SegmentationAgent(NetRunAgent): def __init__(self, config, stage = 'train'): @@ -120,6 +122,7 @@ def set_postprocessor(self, postprocessor): def training(self): class_num = self.config['network']['class_num'] iter_valid = self.config['training']['iter_valid'] + mixup_prob = self.config['training'].get('mixup_probability', 0.5) train_loss = 0 train_dice_list = [] self.net.train() @@ -132,7 +135,9 @@ def training(self): # get the inputs inputs = self.convert_tensor_type(data['image']) labels_prob = self.convert_tensor_type(data['label_prob']) - + if(random() < mixup_prob): + inputs, labels_prob = mixup(inputs, labels_prob) + # # for debug # for i in range(inputs.shape[0]): # image_i = inputs[i][0] diff --git a/pymic/util/general.py b/pymic/util/general.py index 75b6af1..99eb49f 100644 --- a/pymic/util/general.py +++ b/pymic/util/general.py @@ -29,4 +29,52 @@ def get_one_hot_seg(label, class_num): one_hot = one_hot.view(*size) one_hot = torch.transpose(one_hot, 1, -1) one_hot = torch.squeeze(one_hot, -1) - return one_hot \ No newline at end of file + return one_hot + +def mixup(inputs, labels): + """Shuffle a minibatch and do linear interpolation between images and labels. + Both classification and segmentation labels are supported. The targets should + be one-hot labels. + + :param inputs: a tensor of input images with size N X C0 x H x W. + :param labels: a tensor of one-hot labels. The shape is N X C for classification + tasks, and N X C X H X W for segmentation tasks. + """ + input_shape = list(inputs.shape) + label_shape = list(labels.shape) + img_dim = len(input_shape) - 2 + N = input_shape[0] # batch size + C = label_shape[1] # class number + rp1 = torch.randperm(N) + inputs1 = inputs[rp1] + labels1 = labels[rp1] + + rp2 = torch.randperm(N) + inputs2 = inputs[rp2] + labels2 = labels[rp2] + + a = np.random.beta(1, 1, [N, 1]) + if(img_dim == 2): + b = np.tile(a[..., None, None], [1] + input_shape[1:]) + elif(img_dim == 3): + b = np.tile(a[..., None, None, None], [1] + input_shape[1:]) + else: + raise ValueError("MixUp only supports 2D and 3D images, but the " + + "input image has {0:} dimensions".format(img_dim)) + + inputs1 = inputs1 * torch.from_numpy(b).float() + inputs2 = inputs2 * torch.from_numpy(1 - b).float() + inputs_mix = inputs1 + inputs2 + + if(len(label_shape) == 2): # for classification tasks + c = np.tile(a, [1, C]) + elif(img_dim == 2): # for 2D segmentation tasks + c = np.tile(a[..., None, None], [1] + label_shape[1:]) + else: # for 3D segmentation tasks + c = np.tile(a[..., None, None, None], [1] + label_shape[1:]) + + labels1 = labels1 * torch.from_numpy(c).float() + labels2 = labels2 * torch.from_numpy(1 - c).float() + labels_mix = labels1 + labels2 + + return inputs_mix, labels_mix From 981d47acf0550fd93af3c15cde0ec31ba49e339b Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 22 Nov 2022 23:00:47 +0800 Subject: [PATCH 5/8] update network unet2d_urpc is not needed as deep supervision is supported by unet2d and unet3d now --- pymic/net/net2d/unet2d_urpc.py | 132 --------------------------------- pymic/net/net_dict_seg.py | 3 - 2 files changed, 135 deletions(-) delete mode 100644 pymic/net/net2d/unet2d_urpc.py diff --git a/pymic/net/net2d/unet2d_urpc.py b/pymic/net/net2d/unet2d_urpc.py deleted file mode 100644 index ee8ab7c..0000000 --- a/pymic/net/net2d/unet2d_urpc.py +++ /dev/null @@ -1,132 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import print_function, division - -import torch -import torch.nn as nn -import numpy as np -from torch.distributions.uniform import Uniform -from pymic.net.net2d.unet2d import ConvBlock, DownBlock, UpBlock - -def FeatureDropout(x): - attention = torch.mean(x, dim=1, keepdim=True) - max_val, _ = torch.max(attention.view( - x.size(0), -1), dim=1, keepdim=True) - threshold = max_val * np.random.uniform(0.7, 0.9) - threshold = threshold.view(x.size(0), 1, 1, 1).expand_as(attention) - drop_mask = (attention < threshold).float() - x = x.mul(drop_mask) - return x - -class FeatureNoise(nn.Module): - def __init__(self, uniform_range=0.3): - super(FeatureNoise, self).__init__() - self.uni_dist = Uniform(-uniform_range, uniform_range) - - def feature_based_noise(self, x): - noise_vector = self.uni_dist.sample( - x.shape[1:]).to(x.device).unsqueeze(0) - x_noise = x.mul(noise_vector) + x - return x_noise - - def forward(self, x): - x = self.feature_based_noise(x) - return x - -class UNet2D_URPC(nn.Module): - """ - An modification the U-Net to obtain multi-scale prediction according to - the URPC paper. - - * Reference: Xiangde Luo, Guotai Wang*, Wenjun Liao, Jieneng Chen, Tao Song, Yinan Chen, - Shichuan Zhang, Dimitris N. Metaxas, Shaoting Zhang. - Semi-Supervised Medical Image Segmentation via Uncertainty Rectified Pyramid Consistency . - `Medical Image Analysis 2022. `_ - - Also see: https://github.com/HiLab-git/SSL4MIS/blob/master/code/networks/unet.py - - Parameters are given in the `params` dictionary, and should include the - following fields: - - :param in_chns: (int) Input channel number. - :param feature_chns: (list) Feature channel for each resolution level. - The length should be 5, such as [16, 32, 64, 128, 256]. - :param dropout: (list) The dropout ratio for each resolution level. - The length should be the same as that of `feature_chns`. - :param class_num: (int) The class number for segmentation task. - :param bilinear: (bool) Using bilinear for up-sampling or not. - If False, deconvolution will be used for up-sampling. - """ - def __init__(self, params): - super(UNet2D_URPC, self).__init__() - self.params = params - self.in_chns = self.params['in_chns'] - self.ft_chns = self.params['feature_chns'] - self.dropout = self.params['dropout'] - self.n_class = self.params['class_num'] - self.bilinear = self.params['bilinear'] - assert(len(self.ft_chns) == 5) - - self.in_conv= ConvBlock(self.in_chns, self.ft_chns[0], self.dropout[0]) - self.down1 = DownBlock(self.ft_chns[0], self.ft_chns[1], self.dropout[1]) - self.down2 = DownBlock(self.ft_chns[1], self.ft_chns[2], self.dropout[2]) - self.down3 = DownBlock(self.ft_chns[2], self.ft_chns[3], self.dropout[3]) - self.down4 = DownBlock(self.ft_chns[3], self.ft_chns[4], self.dropout[4]) - self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], 0.0, self.bilinear) - self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], 0.0, self.bilinear) - self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], 0.0, self.bilinear) - self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], 0.0, self.bilinear) - - self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, - kernel_size = 3, padding = 1) - self.out_conv_dp4 = nn.Conv2d(self.ft_chns[4], self.n_class, - kernel_size=3, padding=1) - self.out_conv_dp3 = nn.Conv2d(self.ft_chns[3], self.n_class, - kernel_size=3, padding=1) - self.out_conv_dp2 = nn.Conv2d(self.ft_chns[2], self.n_class, - kernel_size=3, padding=1) - self.out_conv_dp1 = nn.Conv2d(self.ft_chns[1], self.n_class, - kernel_size=3, padding=1) - self.feature_noise = FeatureNoise() - - def forward(self, x): - x_shape = list(x.shape) - if(len(x_shape) == 5): - [N, C, D, H, W] = x_shape - new_shape = [N*D, C, H, W] - x = torch.transpose(x, 1, 2) - x = torch.reshape(x, new_shape) - x0 = self.in_conv(x) - x1 = self.down1(x0) - x2 = self.down2(x1) - x3 = self.down3(x2) - x4 = self.down4(x3) - - x = self.up1(x4, x3) - if self.training: - x = nn.functional.dropout(x, p=0.5) - dp3_out = self.out_conv_dp3(x) - - x = self.up2(x, x2) - if self.training: - x = FeatureDropout(x) - dp2_out = self.out_conv_dp2(x) - - x = self.up3(x, x1) - if self.training: - x = self.feature_noise(x) - dp1_out = self.out_conv_dp1(x) - - x = self.up4(x, x0) - dp0_out = self.out_conv(x) - - out_shape = list(dp0_out.shape)[2:] - dp3_out = nn.functional.interpolate(dp3_out, out_shape) - dp2_out = nn.functional.interpolate(dp2_out, out_shape) - dp1_out = nn.functional.interpolate(dp1_out, out_shape) - out = [dp0_out, dp1_out, dp2_out, dp3_out] - - if(len(x_shape) == 5): - new_shape = [N, D] + list(dp0_out.shape)[1:] - for i in range(len(out)): - out[i] = torch.transpose(torch.reshape(out[i], new_shape), 1, 2) - return out \ No newline at end of file diff --git a/pymic/net/net_dict_seg.py b/pymic/net/net_dict_seg.py index 5dd3bee..195896a 100644 --- a/pymic/net/net_dict_seg.py +++ b/pymic/net/net_dict_seg.py @@ -4,7 +4,6 @@ * UNet2D :mod:`pymic.net.net2d.unet2d.UNet2D` * UNet2D_DualBranch :mod:`pymic.net.net2d.unet2d_dual_branch.UNet2D_DualBranch` -* UNet2D_URPC :mod:`pymic.net.net2d.unet2d_urpc.UNet2D_URPC` * UNet2D_CCT :mod:`pymic.net.net2d.unet2d_cct.UNet2D_CCT` * UNet2D_ScSE :mod:`pymic.net.net2d.unet2d_scse.UNet2D_ScSE` * AttentionUNet2D :mod:`pymic.net.net2d.unet2d_attention.AttentionUNet2D` @@ -17,7 +16,6 @@ from __future__ import print_function, division from pymic.net.net2d.unet2d import UNet2D from pymic.net.net2d.unet2d_dual_branch import UNet2D_DualBranch -from pymic.net.net2d.unet2d_urpc import UNet2D_URPC from pymic.net.net2d.unet2d_cct import UNet2D_CCT from pymic.net.net2d.cople_net import COPLENet from pymic.net.net2d.unet2d_attention import AttentionUNet2D @@ -30,7 +28,6 @@ SegNetDict = { 'UNet2D': UNet2D, 'UNet2D_DualBranch': UNet2D_DualBranch, - 'UNet2D_URPC': UNet2D_URPC, 'UNet2D_CCT': UNet2D_CCT, 'COPLENet': COPLENet, 'AttentionUNet2D': AttentionUNet2D, From 647674dd30249162b0c20e6b71b969c7ee739d0d Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 6 Dec 2022 11:03:10 +0800 Subject: [PATCH 6/8] enable automatic installation of dependencies --- pymic/loss/seg/ssl.py | 5 +- pymic/util/evaluation_seg.py | 115 ++++++++++++++++++----------------- pymic/util/post_process.py | 1 + requirements.txt | 4 +- setup.py | 13 +++- 5 files changed, 77 insertions(+), 61 deletions(-) diff --git a/pymic/loss/seg/ssl.py b/pymic/loss/seg/ssl.py index f15fc60..0bf276f 100644 --- a/pymic/loss/seg/ssl.py +++ b/pymic/loss/seg/ssl.py @@ -6,8 +6,9 @@ import torch.nn as nn import numpy as np from pymic.loss.seg.util import reshape_tensor_to_2D +from pymic.loss.seg.abstract import AbstractSegLoss -class EntropyLoss(nn.Module): +class EntropyLoss(AbstractSegLoss): """ Entropy Minimization for segmentation tasks. The parameters should be written in the `params` dictionary, and it has the @@ -43,7 +44,7 @@ def forward(self, loss_input_dict): avg_ent = torch.mean(entropy) return avg_ent -class TotalVariationLoss(nn.Module): +class TotalVariationLoss(AbstractSegLoss): """ Total Variation Loss for segmentation tasks. The parameters should be written in the `params` dictionary, and it has the diff --git a/pymic/util/evaluation_seg.py b/pymic/util/evaluation_seg.py index ec02297..06fc402 100644 --- a/pymic/util/evaluation_seg.py +++ b/pymic/util/evaluation_seg.py @@ -260,14 +260,15 @@ def evaluation(config): Run evaluation of segmentation results based on a configuration dictionary `config`. The following fields should be provided in `config`: - :param metric: (str) The metric for evaluation. + :param metric_list: (list) The list of metrics for evaluation. The metric options are {`dice`, `iou`, `assd`, `hd95`, `rve`, `volume`}. :param label_list: (list) The list of labels for evaluation. :param label_fuse: (option, bool) If true, fuse the labels in the `label_list` as the foreground, and other labels as the background. Default is False. :param organ_name: (str) The name of the organ for segmentation. :param ground_truth_folder_root: (str) The root dir of ground truth images. - :param segmentation_folder_root: (str) The root dir of segmentation images. + :param segmentation_folder_root: (str or list) The root dir of segmentation images. + When a list is given, each list element should be the root dir of the results of one method. :param evaluation_image_pair: (str) The csv file that provide the segmentation images and the corresponding ground truth images. :param ground_truth_label_convert_source: (optional, list) The list of source @@ -280,7 +281,7 @@ def evaluation(config): labels for label conversion in the segmentation. """ - metric = config['metric'] + metric_list = config['metric_list'] label_list = config['label_list'] label_fuse = config.get('label_fuse', False) organ_name = config['organ_name'] @@ -295,60 +296,62 @@ def evaluation(config): segmentation_label_convert_target = config.get('segmentation_label_convert_target', None) image_items = pd.read_csv(image_pair_csv) - item_num = len(image_items) - for seg_root_n in seg_root: - score_all_data = [] - name_score_list= [] - for i in range(item_num): - gt_name = image_items.iloc[i, 0] - seg_name = image_items.iloc[i, 1] - # seg_name = seg_name.replace(".nii.gz", "_pred.nii.gz") - gt_full_name = gt_root + '/' + gt_name - seg_full_name = seg_root_n + '/' + seg_name - - s_dict = load_image_as_nd_array(seg_full_name) - g_dict = load_image_as_nd_array(gt_full_name) - s_volume = s_dict["data_array"]; s_spacing = s_dict["spacing"] - g_volume = g_dict["data_array"]; g_spacing = g_dict["spacing"] - # for dim in range(len(s_spacing)): - # assert(s_spacing[dim] == g_spacing[dim]) - if((ground_truth_label_convert_source is not None) and \ - ground_truth_label_convert_target is not None): - g_volume = convert_label(g_volume, ground_truth_label_convert_source, \ - ground_truth_label_convert_target) - - if((segmentation_label_convert_source is not None) and \ - segmentation_label_convert_target is not None): - s_volume = convert_label(s_volume, segmentation_label_convert_source, \ - segmentation_label_convert_target) - - score_vector = get_multi_class_evaluation_score(s_volume, g_volume, label_list, - label_fuse, s_spacing, metric ) - if(len(label_list) > 1): - score_vector.append(np.asarray(score_vector).mean()) - score_all_data.append(score_vector) - name_score_list.append([seg_name] + score_vector) - print(seg_name, score_vector) - score_all_data = np.asarray(score_all_data) - score_mean = score_all_data.mean(axis = 0) - score_std = score_all_data.std(axis = 0) - name_score_list.append(['mean'] + list(score_mean)) - name_score_list.append(['std'] + list(score_std)) + item_num = len(image_items) - # save the result as csv - score_csv = "{0:}/{1:}_{2:}_all.csv".format(seg_root_n, organ_name, metric) - with open(score_csv, mode='w') as csv_file: - csv_writer = csv.writer(csv_file, delimiter=',', - quotechar='"',quoting=csv.QUOTE_MINIMAL) - head = ['image'] + ["class_{0:}".format(i) for i in label_list] - if(len(label_list) > 1): - head = head + ["average"] - csv_writer.writerow(head) - for item in name_score_list: - csv_writer.writerow(item) - - print("{0:} mean ".format(metric), score_mean) - print("{0:} std ".format(metric), score_std) + for seg_root_n in seg_root: # for each segmentation method + for metric in metric_list: + score_all_data = [] + name_score_list= [] + for i in range(item_num): + gt_name = image_items.iloc[i, 0] + seg_name = image_items.iloc[i, 1] + # seg_name = seg_name.replace(".nii.gz", "_pred.nii.gz") + gt_full_name = gt_root + '/' + gt_name + seg_full_name = seg_root_n + '/' + seg_name + + s_dict = load_image_as_nd_array(seg_full_name) + g_dict = load_image_as_nd_array(gt_full_name) + s_volume = s_dict["data_array"]; s_spacing = s_dict["spacing"] + g_volume = g_dict["data_array"]; g_spacing = g_dict["spacing"] + # for dim in range(len(s_spacing)): + # assert(s_spacing[dim] == g_spacing[dim]) + if((ground_truth_label_convert_source is not None) and \ + ground_truth_label_convert_target is not None): + g_volume = convert_label(g_volume, ground_truth_label_convert_source, \ + ground_truth_label_convert_target) + + if((segmentation_label_convert_source is not None) and \ + segmentation_label_convert_target is not None): + s_volume = convert_label(s_volume, segmentation_label_convert_source, \ + segmentation_label_convert_target) + + score_vector = get_multi_class_evaluation_score(s_volume, g_volume, label_list, + label_fuse, s_spacing, metric ) + if(len(label_list) > 1): + score_vector.append(np.asarray(score_vector).mean()) + score_all_data.append(score_vector) + name_score_list.append([seg_name] + score_vector) + print(seg_name, score_vector) + score_all_data = np.asarray(score_all_data) + score_mean = score_all_data.mean(axis = 0) + score_std = score_all_data.std(axis = 0) + name_score_list.append(['mean'] + list(score_mean)) + name_score_list.append(['std'] + list(score_std)) + + # save the result as csv + score_csv = "{0:}/{1:}_{2:}_all.csv".format(seg_root_n, organ_name, metric) + with open(score_csv, mode='w') as csv_file: + csv_writer = csv.writer(csv_file, delimiter=',', + quotechar='"',quoting=csv.QUOTE_MINIMAL) + head = ['image'] + ["class_{0:}".format(i) for i in label_list] + if(len(label_list) > 1): + head = head + ["average"] + csv_writer.writerow(head) + for item in name_score_list: + csv_writer.writerow(item) + + print("{0:} mean ".format(metric), score_mean) + print("{0:} std ".format(metric), score_std) def main(): """ diff --git a/pymic/util/post_process.py b/pymic/util/post_process.py index a0a9dff..a889b23 100644 --- a/pymic/util/post_process.py +++ b/pymic/util/post_process.py @@ -43,6 +43,7 @@ def __call__(self, seg): seg_c = np.asarray(seg == c, np.uint8) seg_c = get_largest_k_components(seg_c) output = output + seg_c * c + seg = output return seg PostProcessDict = { diff --git a/requirements.txt b/requirements.txt index c8cd562..6dac753 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,5 +7,5 @@ scipy>=1.3.3 SimpleITK>=2.0.0 tensorboard>=2.1.0 tensorboardX>=1.9 -torch>=1.7.1 -torchvision>=0.8.2 +torch>=1.1.12 +torchvision>=0.13.0 diff --git a/setup.py b/setup.py index ce7271b..527bdcb 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ setuptools.setup( name = 'PYMIC', - version = "0.3.0", + version = "0.3.1", author ='PyMIC Consortium', author_email = 'wguotai@gmail.com', description = description, @@ -20,6 +20,17 @@ url = 'https://github.com/HiLab-git/PyMIC', license = 'Apache 2.0', packages = setuptools.find_packages(), + install_requires=[ + "matplotlib>=3.1.2", + "numpy>=1.17.4", + "pandas>=0.25.3", + "scikit-image>=0.16.2", + "scikit-learn>=0.22", + "scipy>=1.3.3", + "SimpleITK>=2.0.0", + "tensorboard>=2.1.0", + "tensorboardX>=1.9", + ], classifiers=[ 'License :: OSI Approved :: Apache Software License', 'Programming Language :: Python', From 8928f6bdc278b31b080d0f443e414a8e0524ad03 Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 6 Dec 2022 13:11:12 +0800 Subject: [PATCH 7/8] update distance evaluation --- pymic/test/test_assd.py | 37 ++++++++++++++++++++++++++++++++++++ pymic/util/evaluation_seg.py | 26 ++++++------------------- requirements.txt | 4 ++-- setup.py | 6 +++--- 4 files changed, 48 insertions(+), 25 deletions(-) create mode 100644 pymic/test/test_assd.py diff --git a/pymic/test/test_assd.py b/pymic/test/test_assd.py new file mode 100644 index 0000000..35c1804 --- /dev/null +++ b/pymic/test/test_assd.py @@ -0,0 +1,37 @@ +from scipy import ndimage +from PIL import Image +import numpy as np +import SimpleITK as sitk +import matplotlib.pyplot as plt +from pymic.util.evaluation_seg import get_edge_points + +def test_assd_2d(): + img_name = "/home/x/projects/PyMIC_project/PyMIC_examples/PyMIC_data/JSRT/label/JPCLN001.png" + img = Image.open(img_name) + img_array = np.asarray(img) + img_edge = get_edge_points(img_array > 0) + s_dis = ndimage.distance_transform_edt(1-img_edge) + plt.subplot(1,2,1) + plt.imshow(img_edge) + plt.subplot(1,2,2) + plt.imshow(s_dis) + plt.show() + +def test_assd_3d(): + img_name = "/home/x/projects/PyMIC_project/PyMIC_examples/seg_ssl/ACDC/result/unet2d_baseline/patient001_frame01.nii.gz" + img_obj = sitk.ReadImage(img_name) + spacing = img_obj.GetSpacing() + spacing = spacing[::-1] + img_data = sitk.GetArrayFromImage(img_obj) + print(img_data.shape) + print(spacing) + img_edge = get_edge_points(img_data > 0) + s_dis = ndimage.distance_transform_edt(1-img_edge, sampling=spacing) + dis_obj = sitk.GetImageFromArray(s_dis) + dis_obj.CopyInformation(img_obj) + sitk.WriteImage(dis_obj, "test_dis.nii.gz") + + + +if __name__ == "__main__": + test_assd_3d() \ No newline at end of file diff --git a/pymic/util/evaluation_seg.py b/pymic/util/evaluation_seg.py index 06fc402..ba04a73 100644 --- a/pymic/util/evaluation_seg.py +++ b/pymic/util/evaluation_seg.py @@ -6,11 +6,7 @@ import csv import os import sys -import math import pandas as pd -import random -import GeodisTK -import configparser import numpy as np from scipy import ndimage from pymic.io.image_read_write import * @@ -90,7 +86,7 @@ def get_edge_points(img): strt = ndimage.generate_binary_structure(2,1) else: strt = ndimage.generate_binary_structure(3,1) - ero = ndimage.morphology.binary_erosion(img, strt) + ero = ndimage.binary_erosion(img, strt) edge = np.asarray(img, np.uint8) - np.asarray(ero, np.uint8) return edge @@ -114,14 +110,9 @@ def binary_hd95(s, g, spacing = None): spacing = [1.0] * image_dim else: assert(image_dim == len(spacing)) - img = np.zeros_like(s) - if(image_dim == 2): - s_dis = GeodisTK.geodesic2d_raster_scan(img, s_edge, 0.0, 2) - g_dis = GeodisTK.geodesic2d_raster_scan(img, g_edge, 0.0, 2) - elif(image_dim ==3): - s_dis = GeodisTK.geodesic3d_raster_scan(img, s_edge, spacing, 0.0, 2) - g_dis = GeodisTK.geodesic3d_raster_scan(img, g_edge, spacing, 0.0, 2) - + s_dis = ndimage.distance_transform_edt(1-s_edge, sampling = spacing) + g_dis = ndimage.distance_transform_edt(1-g_edge, sampling = spacing) + dist_list1 = s_dis[g_edge > 0] dist_list1 = sorted(dist_list1) dist1 = dist_list1[int(len(dist_list1)*0.95)] @@ -150,13 +141,8 @@ def binary_assd(s, g, spacing = None): spacing = [1.0] * image_dim else: assert(image_dim == len(spacing)) - img = np.zeros_like(s) - if(image_dim == 2): - s_dis = GeodisTK.geodesic2d_raster_scan(img, s_edge, 0.0, 2) - g_dis = GeodisTK.geodesic2d_raster_scan(img, g_edge, 0.0, 2) - elif(image_dim ==3): - s_dis = GeodisTK.geodesic3d_raster_scan(img, s_edge, spacing, 0.0, 2) - g_dis = GeodisTK.geodesic3d_raster_scan(img, g_edge, spacing, 0.0, 2) + s_dis = ndimage.distance_transform_edt(1-s_edge, sampling = spacing) + g_dis = ndimage.distance_transform_edt(1-g_edge, sampling = spacing) ns = s_edge.sum() ng = g_edge.sum() diff --git a/requirements.txt b/requirements.txt index 6dac753..49912a4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ scikit-image>=0.16.2 scikit-learn>=0.22 scipy>=1.3.3 SimpleITK>=2.0.0 -tensorboard>=2.1.0 -tensorboardX>=1.9 +tensorboard +tensorboardX torch>=1.1.12 torchvision>=0.13.0 diff --git a/setup.py b/setup.py index 527bdcb..cdf2295 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ setuptools.setup( name = 'PYMIC', - version = "0.3.1", + version = "0.3.2.2", author ='PyMIC Consortium', author_email = 'wguotai@gmail.com', description = description, @@ -28,8 +28,8 @@ "scikit-learn>=0.22", "scipy>=1.3.3", "SimpleITK>=2.0.0", - "tensorboard>=2.1.0", - "tensorboardX>=1.9", + "tensorboard", + "tensorboardX", ], classifiers=[ 'License :: OSI Approved :: Apache Software License', From 1e7994755f5e63b41af3888ec8e280b6115f2828 Mon Sep 17 00:00:00 2001 From: taigw Date: Tue, 6 Dec 2022 16:56:26 +0800 Subject: [PATCH 8/8] update the version all automatic installation of dependencies --- README.md | 4 ++-- setup.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 1cf7b3c..d338581 100644 --- a/README.md +++ b/README.md @@ -47,10 +47,10 @@ Run the following command to install the latest released version of PyMIC: ```bash pip install PYMIC ``` -To install a specific version of PYMIC such as 0.3.0, run: +To install a specific version of PYMIC such as 0.3.1, run: ```bash -pip install PYMIC==0.3.0 +pip install PYMIC==0.3.1 ``` Alternatively, you can download the source code for the latest version. Run the following command to compile and install: diff --git a/setup.py b/setup.py index cdf2295..9a7f38b 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ setuptools.setup( name = 'PYMIC', - version = "0.3.2.2", + version = "0.3.1", author ='PyMIC Consortium', author_email = 'wguotai@gmail.com', description = description,