In [1]:
import os
import sys
from tqdm import tqdm
from tensorboardX import SummaryWriter
import shutil
import argparse
import logging
import time
import random
import numpy as np

import torch
import torch.optim as optim
from torchvision import transforms
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torchvision.utils import make_grid

from vnet import VNet
# from networks.vnet import VNet
from dataloaders import utils
from utils import ramps, losses
#from dataloaders.CTMSpine_sitk import CTMSpine, CTMSpine_unseg, RandomScale, RandomNoise, RandomCrop, CenterCrop, RandomRot, RandomFlip, ToTensor, TransformConsistantOperator

In [3]:
# -*- coding: utf-8 -*-
# # 说明:
# 此处的label均是onehot,最后一个通道是类别通道

import os
import torch
import numpy as np
import random
# from glob import glob
from torch.utils.data import Dataset
import h5py
import itertools
from torch.utils.data.sampler import Sampler
# import cv2
from skimage import transform

class CTMSpine(Dataset):
    """ CTM Spine Dataset """
    def __init__(self, base_dir=None, split='train', num=None, transform=None, filename="mri_norm2.h5"):
        self._base_dir = base_dir
        self.transform = transform
        self.filename = filename
        self.sample_list = []
        if split=='train':
            with open(self._base_dir+'/../train.list', 'r') as f:
                self.image_list = f.readlines()
        elif split == 'test':
            with open(self._base_dir+'/../test.list', 'r') as f:
                self.image_list = f.readlines()
        self.image_list = [item.replace('\n','') for item in self.image_list]
        if num is not None:
            self.image_list = self.image_list[:num]
        print("total {} samples".format(len(self.image_list)))

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        image_name = self.image_list[idx]
#         image = sitk.ReadImage(self._base_dir+"/"+image_name+"/image.nii.gz")
#         label = sitk.ReadImage(self._base_dir+"/"+image_name+"/label_onehot.nii.gz")

        h5f = h5py.File(self._base_dir+"/"+image_name+"/"+self.filename, 'r')
        image = h5f['image'][:]
        label = h5f['label'][:]
        label = np.argmax(label,axis=-1)
        sample = {'image': image, 'label': label}
        if self.transform:
            sample = self.transform(sample)

        return sample


class CTMSpine_unseg(Dataset):
    """ LA Dataset """
    def __init__(self, base_dir=None, num=None, transform=None, filename="center_cut.h5"):
        self._base_dir = base_dir
        self.transform = transform
        self.filename = filename
        self.sample_list = []
        print(self._base_dir+'/../../train_unseg_centercut.list')
        with open(self._base_dir+'/../../train_unseg.list', 'r') as f:
            self.image_list = f.readlines()

        self.image_list = [item.replace('\n','') for item in self.image_list]
        if num is not None:
            self.image_list = self.image_list[:num]
        print("total {} samples".format(len(self.image_list)))

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        #print(self.image_list)
        #print("check: ",idx,len(self.image_list))
        image_name = self.image_list[idx]
        h5f = h5py.File(self._base_dir+"/"+image_name+"/"+self.filename, 'r')
        image = h5f['image'][:]
        sample = {'image': image,'label':None}
        if self.transform:
            sample = self.transform(sample)

        return sample

class CenterCrop(object):
    def __init__(self, output_size):
        self.output_size = output_size

    def __call__(self, sample):
        image, label = sample['image'], sample['label']

        # pad the sample if necessary
        if image.shape[0] <= self.output_size[0] or image.shape[1] <= self.output_size[1] or image.shape[2] <= \
                self.output_size[2]:
            pw = max((self.output_size[0] - image.shape[0]) // 2 + 3, 0)
            ph = max((self.output_size[1] - image.shape[1]) // 2 + 3, 0)
            pd = max((self.output_size[2] - image.shape[2]) // 2 + 3, 0)
            image = np.pad(image, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)
            if label is not None:
                import pdb
                pdb.set_trace()
                label = np.pad(label, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)

        (w, h, d) = image.shape

        w1 = int(round((w - self.output_size[0]) / 2.))
        h1 = int(round((h - self.output_size[1]) / 2.))
        d1 = int(round((d - self.output_size[2]) / 2.))

        image = image[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
        if label:
            label = label[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]

        return {'image': image, 'label': label}


class RandomCrop(object):
    """
    Crop randomly the image in a sample
    Args:
    output_size (int): Desired output size
    """

    def __init__(self, output_size):
        self.output_size = output_size

    def __call__(self, sample):
        image, label = sample['image'], sample['label']

        # pad the sample if necessary
        if image.shape[0] <= self.output_size[0] or image.shape[1] <= self.output_size[1] or image.shape[2] <= \
                self.output_size[2]:
            pw = max((self.output_size[0] - image.shape[0]) // 2 + 3, 0)
            ph = max((self.output_size[1] - image.shape[1]) // 2 + 3, 0)
            pd = max((self.output_size[2] - image.shape[2]) // 2 + 3, 0)
            image = np.pad(image, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)
            if label is not None:
                label = np.pad(label, [(pw, pw), (ph, ph), (pd, pd)], mode='constant', constant_values=0)

        (w, h, d) = image.shape
        # if np.random.uniform() > 0.33:
        #     w1 = np.random.randint((w - self.output_size[0])//4, 3*(w - self.output_size[0])//4)
        #     h1 = np.random.randint((h - self.output_size[1])//4, 3*(h - self.output_size[1])//4)
        # else:
        w1 = np.random.randint(0, w - self.output_size[0])
        h1 = np.random.randint(0, h - self.output_size[1])
        d1 = np.random.randint(0, d - self.output_size[2])

        image = image[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
        if label is not None:
            label = label[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
        return {'image': image, 'label': label}

# +
import SimpleITK as sitk
def resample_image3D(
    image3D,
    spacing=[0.3,0.3,3],
    ratio=1.0,
    method='Linear',):
    """做插值"""
    resample = sitk.ResampleImageFilter()
    import pdb
    pdb.set_trace()
    if method == 'Linear':
        resample.SetInterpolator(sitk.sitkLinear)
    elif method == 'Nearest':
        resample.SetInterpolator(sitk.sitkNearestNeighbor)
    resample.SetOutputDirection( (1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0) )
    resample.SetOutputOrigin((0,0,0))
    resample.SetOutputSpacing( (np.array(spacing)*ratio).tolist() )
    
    newsize = np.round(np.array(image3D.shape)*ratio).astype('int').tolist() 
    resample.SetSize(newsize)
    # resample.SetDefaultPixelValue(0)
    print("image3D.shape:",image3D.shape)
    image3D = sitk.GetImageFromArray(image3D)
    image3D.SetSpacing(spacing)
    image3D.SetDirection( (1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0) )
    image3D.SetOrigin((0,0,0))
    
    newimage = resample.Execute(image3D)
    newimage = sitk.GetArrayFromImage(newimage)
#     print("newimage.shape:",newimage.shape)
    return newimage

# def resample_image(image, spacing, ratio, is_label=False):
#     # image: 3D image, format: narray
#     out_spacing = (np.array(spacing)*ratio).tolist()
#     out_size = np.round(np.array(image.shape)*ratio).astype('int').tolist() 
#     image = sitk.GetImageFromArray(image)
#     image.SetSpacing(spacing)
#     image.SetDirection( (1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0) )
#     image.SetOrigin((0,0,0))

#     resample = sitk.ResampleImageFilter()
#     resample.SetOutputSpacing(out_spacing)
#     resample.SetSize(out_size)
#     resample.SetOutputDirection( (1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0) )
#     resample.SetOutputOrigin( (0,0,0) )
#     resample.SetTransform(sitk.Transform())
#     resample.SetDefaultPixelValue(2)

#     if is_label:
#         resample.SetInterpolator(sitk.sitkNearestNeighbor)
#     else:
#         resample.SetInterpolator(sitk.sitkBSpline)
#     import pdb
#     pdb.set_trace()
    
#     out_image = resample.Execute(image) 
#     out_image = sitk.GetArrayFromImage(out_image)
#     return resample.Execute(image) 

# def resample_image(image, spacing, ratio, is_label=False):
#     # image: 3D image, format: narray
#     out_spacing = (np.array(spacing)*ratio).tolist()
#     out_size = np.round(np.array(image.shape)*ratio).astype('int').tolist() 
#     image = sitk.GetImageFromArray(image)
#     image.SetSpacing(spacing)
#     image.SetDirection( (1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0) )
#     image.SetOrigin((0,0,0))

#     resample = sitk.ResampleImageFilter()
#     resample.SetOutputSpacing(out_spacing)
#     resample.SetSize(out_size)
#     resample.SetOutputDirection( (1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0) )
#     resample.SetOutputOrigin( (0,0,0) )
#     resample.SetTransform(sitk.Transform())
#     resample.SetDefaultPixelValue(2)

#     if is_label:
#         resample.SetInterpolator(sitk.sitkNearestNeighbor)
#     else:
#         resample.SetInterpolator(sitk.sitkBSpline)
#     import pdb
#     pdb.set_trace()
    
#     out_image = resample.Execute(image) 
#     out_image = sitk.GetArrayFromImage(out_image)
#     return resample.Execute(image) 

# def resample_image(image, label, ratio):
#     sitkImage = sitk.GetImageFromArray(image, isVector=False)
#     sitklabel = sitk.GetImageFromArray(label, isVector=False)

#     itemindex = np.where(label > 0)
#     randTrans = (0,np.random.randint(-np.min(itemindex[1])/2,(image.shape[1]-np.max(itemindex[1]))/2),np.random.randint(-np.min(itemindex[0])/2,(image.shape[0]-np.max(itemindex[0]))/2))
#     translation = sitk.TranslationTransform(3, randTrans)

#     resampler = sitk.ResampleImageFilter()
#     resampler.SetReferenceImage(sitkImage)
#     resampler.SetInterpolator(sitk.sitkLinear)#sitk.sitkBSpline
#     resampler.SetDefaultPixelValue(0)
#     resampler.SetTransform(translation)

#     outimgsitk = resampler.Execute(sitkImage)
    
#     resampler.SetInterpolator(sitk.sitkNearestNeighbor)
#     outlabsitk = resampler.Execute(sitklabel)

#     outimg = sitk.GetArrayFromImage(outimgsitk)
#     outimg = outimg.astype(dtype=float)

#     outlbl = sitk.GetArrayFromImage(outlabsitk) > 0
#     outlbl = outlbl.astype(dtype=float)

#     return outimg, outlbl 

def resample_image_sitk(image_sitk, label_sitk=None, newspacing=None, out_size=None): 
    resample = sitk.ResampleImageFilter()
    resample.SetOutputDirection(image_sitk.GetDirection())
    resample.SetOutputOrigin(image_sitk.GetOrigin())
    resample.SetOutputSpacing(newspacing)
    
    if not out_size:
        out_size = np.round(np.array(image_sitk.GetSize())*np.abs(image_sitk.GetSpacing())/np.array(newspacing)).astype('int').tolist()

    resample.SetSize(out_size)
    # resample.SetDefaultPixelValue(0)
    
    resample.SetInterpolator(sitk.sitkLinear)
    out_image = resample.Execute(image_sitk)
    out_image = sitk.GetArrayFromImage(out_image).transpose((2,1,0)).astype(dtype=float)
    if label_sitk is None:
        return out_image,None
    else:
        resample.SetInterpolator(sitk.sitkNearestNeighbor)
        out_label = resample.Execute(label_sitk)
        out_label = sitk.GetArrayFromImage(out_label).transpose((2,1,0,3)).astype(dtype=float)
        return out_image, out_label



# -

class RandomScale(object):
    """
    Scale randomly the image within the scaling ratio of 0.8-1.2
    Args:
    ratio_low, ratio_high (float): Desired ratio range of random scale 
    """

    def __init__(self, ratio_low, ratio_high):
        self.ratio_low = ratio_low
        self.ratio_high = ratio_high

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        print("type: ",type(image))
        
        # rescale
        ratio = np.random.uniform(self.ratio_low, self.ratio_high)
#         image = transform.rescale(image,ratio,order=1,anti_aliasing=True,preserve_range=True,multichannel=False) 
#         image = resample_image(image, spacing=[0.3, 0.3, 3.0], ratio=ratio, is_label=False)
        image,label = resample_image_sitk(image, label, [1.0, 1.0, 1.0])
        
        assert np.unique(label).tolist() == [0,1,2], "np.unique(label):"+str(np.unique(label).tolist())
        if label is not None:
            image,label = resample_image_sitk(image, label, [1.0, 1.0, 1.0])
        else:
            image = resample_image_sitk(image, None, [1.0, 1.0, 1.0], None)
#             label = transform.rescale(label,ratio,order=0,anti_aliasing=True,preserve_range=True,multichannel=False)
            #label = resample_image3D(label,spacing=[0.3,0.3,3],ratio=ratio,method='Nearest')
#             label = resample_image(image, spacing=[0.3, 0.3, 3.0], ratio=ratio, is_label=True)
#             label = np.argmax(label,axis=-1)
#         assert np.unique(label).tolist() == [0,1,2], "np.unique(rescaled label):"+str(np.unique(label).tolist())
#         print("image.shape",image.shape,
#               "label.shape",label.shape,
#               "ratio,dsize:",ratio,dsize,
#               "np.unique(label):",np.unique(label),
#              )
        return {'image': image, 'label': label}

# +
class TransformConsistantOperator():
    """
    Crop randomly flip the dataset in a sample
    Args:
    output_size (int): Desired output size
    """
    def __init__(self, k=None, axis=None):
        if k is not None:
            self.k = k
        else:
            self.k = np.random.randint(0, 4)
        if axis is not None:
            self.axis = axis
        else:
            self.axis = np.random.randint(0, 2)
            
    def transform(self, image):
        """image could be image or mask"""
        image = image.permute(2,3,4,0,1)
        image = torch.rot90(image, self.k)#np.rot90(image, self.k)
        image = torch.flip(image, dims=[self.axis])#np.flip(image, axis=self.axis).copy()
        image = image.permute(3,4,0,1,2)

#         image = image.permute(2,3,4,0,1).cpu()
#         image = np.rot90(image, self.k)
#         image = np.flip(image, axis=self.axis).copy()
#         image = torch.from_numpy( image.transpose((3,4,0,1,2)).copy() )
        return image
    
    def inv_transform(self, image):
        """image could be image or mask"""
        image = image.permute(2,3,4,0,1)
        image = torch.flip(image, dims=[self.axis])
        image = torch.rot90(image, -self.k)
        image = image.permute(3,4,0,1,2)

#         image = image.permute(2,3,4,0,1).cpu()
#         import pdb
#         pdb.set_trace()
#         image = np.flip(image, axis=self.axis).copy()
#         image = np.rot90(image, -self.k)
#         image = torch.from_numpy( image.transpose((3,4,0,1,2)).copy() )

        return image


# -

class RandomRot(object):
    """
    Randomly rotate the dataset in a sample
    Args:
    output_size (int): Desired output size
    """

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        k = np.random.randint(0, 4)
        image = np.rot90(image, k)
        if label is not None:
            label = np.rot90(label, k)

        return {'image': image, 'label': label}

class RandomFlip(object):
    """
    Randomly flip the dataset in a sample
    Args:
    output_size (int): Desired output size
    """

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        
        flip = random.sample([True,False], 1)
        if flip:
            axis = np.random.randint(0, 2)
            image = np.flip(image, axis=axis).copy()
        if label is not None:
            if flip:
                label = np.flip(label, axis=axis).copy()

        return {'image': image, 'label': label}


class RandomNoise(object):
    def __init__(self, mu=0, sigma=0.1):
        self.mu = mu
        self.sigma = sigma

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        noise = np.clip(self.sigma * np.random.randn(image.shape[0], image.shape[1], image.shape[2]), -2*self.sigma, 2*self.sigma)
        noise = noise + self.mu
        image = image + noise
        return {'image': image, 'label': label}


class CreateOnehotLabel(object):
    def __init__(self, num_classes):
        self.num_classes = num_classes

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        onehot_label = np.zeros((self.num_classes, label.shape[0], label.shape[1], label.shape[2]), dtype=np.float32)
        for i in range(self.num_classes):
            onehot_label[i, :, :, :] = (label == i).astype(np.float32)
        return {'image': image, 'label': label,'onehot_label':onehot_label}


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        image = sample['image']
        image = image.reshape(1, image.shape[0], image.shape[1], image.shape[2]).astype(np.float32)
        label = sample['label']
        
        if label is not None:
            if 'onehot_label' in sample:
                return {'image': torch.from_numpy(image), 'label': torch.from_numpy(sample['label']).long(),
                        'onehot_label': torch.from_numpy(sample['onehot_label']).long()}
            else:
                return {'image': torch.from_numpy(image), 'label': torch.from_numpy(sample['label']).long()}
        else:
            if 'onehot_label' in sample:
                return {'image': torch.from_numpy(image),
                        'onehot_label': torch.from_numpy(sample['onehot_label']).long()}
            else:
                return {'image': torch.from_numpy(image)}


class TwoStreamBatchSampler(Sampler):
    """Iterate two sets of indices

    An 'epoch' is one iteration through the primary indices.
    During the epoch, the secondary indices are iterated through
    as many times as needed.
    """
    def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size):
        self.primary_indices = primary_indices
        self.secondary_indices = secondary_indices
        self.secondary_batch_size = secondary_batch_size
        self.primary_batch_size = batch_size - secondary_batch_size

        assert len(self.primary_indices) >= self.primary_batch_size > 0
        assert len(self.secondary_indices) >= self.secondary_batch_size > 0

    def __iter__(self):
        primary_iter = iterate_once(self.primary_indices)
        secondary_iter = iterate_eternally(self.secondary_indices)
        return (
            primary_batch + secondary_batch
            for (primary_batch, secondary_batch)
            in zip(grouper(primary_iter, self.primary_batch_size),
                    grouper(secondary_iter, self.secondary_batch_size))
        )

    def __len__(self):
        return len(self.primary_indices) // self.primary_batch_size

def iterate_once(iterable):
    return np.random.permutation(iterable)


def iterate_eternally(indices):
    def infinite_shuffles():
        while True:
            yield np.random.permutation(indices)
    return itertools.chain.from_iterable(infinite_shuffles())


def grouper(iterable, n):
    "Collect data into fixed-length chunks or blocks"
    # grouper('ABCDEFG', 3) --> ABC DEF"
    args = [iter(iterable)] * n
    return zip(*args)


In [4]:
parser = argparse.ArgumentParser()
parser.add_argument('--root_path_labeled', type=str, default='../../data/gz_dataset/segmented')
parser.add_argument('--root_path_unlabeled', type=str, default='../../data/gz_dataset/unsegmented/success')
parser.add_argument('--exp', type=str,  default='UAMT_unlabel', help='model_name')
parser.add_argument('--max_iterations', type=int,  default=6000, help='maximum epoch number to train')
parser.add_argument('--batch_size', type=int, default=2, help='batch_size per gpu')
parser.add_argument('--labeled_bs', type=int, default=1, help='labeled_batch_size per gpu')
parser.add_argument('--base_lr', type=float,  default=0.01, help='maximum epoch number to train')
parser.add_argument('--deterministic', type=int,  default=1, help='whether use deterministic training')
parser.add_argument('--seed', type=int,  default=1337, help='random seed')
parser.add_argument('--gpu', type=str,  default='0', help='GPU to use')
### costs
parser.add_argument('--ema_decay', type=float,  default=0.99, help='ema_decay')
parser.add_argument('--consistency_type', type=str,  default="mse", help='consistency_type')
parser.add_argument('--consistency', type=float,  default=0.1, help='consistency')
parser.add_argument('--consistency_rampup', type=float,  default=40.0, help='consistency_rampup')
args = parser.parse_args(args=[])


In [5]:
labeled_train_data_path = args.root_path_labeled
unlabeled_train_data_path = args.root_path_unlabeled
snapshot_path = "../model/" + args.exp + "/"

In [6]:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
torch.cuda.empty_cache()
batch_size = args.batch_size * len(args.gpu.split(','))
print(batch_size)
max_iterations = args.max_iterations
base_lr = args.base_lr
labeled_bs = args.labeled_bs * len(args.gpu.split(','))
print(labeled_bs)

if args.deterministic:
    cudnn.benchmark = False
    cudnn.deterministic = True
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

num_classes = 4
patch_size = (128, 128, 64)#(128, 128, 64)
cls_weights = [1,5,5,15]

4
2


In [7]:
def get_current_consistency_weight(epoch):
    # Consistency ramp-up from https://arxiv.org/abs/1610.02242
    return args.consistency * ramps.sigmoid_rampup(epoch, args.consistency_rampup)

def update_ema_variables(model, ema_model, alpha, global_step):
    # Use the true average until the exponential average is more correct
    alpha = min(1 - 1 / (global_step + 1), alpha)
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.data.mul_(alpha).add_(1 - alpha, param.data)


In [None]:
if __name__ == "__main__":
    ## make logger file
    if not os.path.exists(snapshot_path):
        os.makedirs(snapshot_path)
    if os.path.exists(snapshot_path + '/code'):
        shutil.rmtree(snapshot_path + '/code')
    shutil.copytree('.', snapshot_path + '/code', shutil.ignore_patterns(['.git','__pycache__']))

    logging.basicConfig(filename=snapshot_path+"/log.txt", level=logging.INFO,
                        format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info(str(args))

    def create_model(ema=False):
        # Network definition
        net = VNet(n_channels=1, n_classes=num_classes, normalization='batchnorm', has_dropout=True)
        model = net.cuda()
        if ema:
            for param in model.parameters():
                param.detach_()
        return model

    model = create_model()
    ema_model = create_model(ema=True)
    
    #　pytorch 的数据加载到模型的操作顺序（三板斧）：
    #    ① 创建一个 Dataset 对象
    #    ② 创建一个 DataLoader 对象
    #    ③ 循环这个 DataLoader 对象，将img, label加载到模型中进行训练
    db_train_labeled = CTMSpine(
        base_dir=labeled_train_data_path,
        split='train',
        transform = transforms.Compose([
            #RandomScale(ratio_low=0.8, ratio_high=1.2),
            RandomNoise(mu=0, sigma=0.05),
            RandomRot(),
            RandomFlip(),
            RandomCrop(patch_size),
            ToTensor(),
        ]))
    db_train_unlabeled = CTMSpine_unseg(
        base_dir=unlabeled_train_data_path,
        transform = transforms.Compose([
            #RandomScale(ratio_low=0.8, ratio_high=1.2),
            RandomNoise(mu=0, sigma=0.05),
            RandomRot(),
            RandomFlip(),
            RandomCrop(patch_size),
            ToTensor(),
        ]))#因为计算一致性损失时增加了噪声，所以不在此处加噪声
#     db_test = LAHeart(base_dir=labeled_train_data_path,
#                        split='test',
#                        transform = transforms.Compose([
#                            CenterCrop(patch_size),
#                            ToTensor()
#                        ]))
    

    def worker_init_fn(worker_id):
        random.seed(args.seed+worker_id)
    # 在linux系统中可以使用多个子进程加载数据，而在windows系统中不能。所以在windows中要将DataLoader中的num_workers设置为0或者采用默认为0的设置。
    #trainloader = DataLoader(db_train, batch_sampler=batch_sampler, num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn)
    #trainloader = DataLoader(db_train_labeled, batch_sampler=batch_sampler, num_workers=0, pin_memory=True, worker_init_fn=worker_init_fn)
    
    labeled_trainloader = DataLoader(db_train_labeled, batch_size=labeled_bs, shuffle=True, num_workers=2, pin_memory=True, worker_init_fn=worker_init_fn)
    unlabeled_trainloader = DataLoader(db_train_unlabeled, batch_size=batch_size-labeled_bs, shuffle=True, num_workers=2, pin_memory=True, worker_init_fn=worker_init_fn)

    model.train()
    ema_model.train()
    optimizer = optim.SGD(model.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001)

    if args.consistency_type == 'mse':
        consistency_criterion = losses.softmax_mse_loss
    elif args.consistency_type == 'kl':
        consistency_criterion = losses.softmax_kl_loss
    else:
        assert False, args.consistency_type

    writer = SummaryWriter(snapshot_path+'/log')
    logging.info("{} itertations per epoch".format(len(unlabeled_trainloader)))

    iter_num = 0
    max_epoch = max_iterations//len(labeled_trainloader)+1
    print("max_epoch:",max_epoch)
    lr_ = base_lr
    model.train()
    for epoch_num in tqdm(range(max_epoch), ncols=70):
        time1 = time.time()
        for i_batch, (sampled_batch_labeled, sampled_batch_unlabeled) in enumerate( zip(labeled_trainloader, unlabeled_trainloader) ):
            time2 = time.time()
            # print('fetch data cost {}'.format(time2-time1))
            labeled_volume_batch, label_batch = sampled_batch_labeled['image'], sampled_batch_labeled['label']
            unlabeled_volume_batch = sampled_batch_unlabeled['image']
            unlabeled_volume_batch = torch.cat((labeled_volume_batch,unlabeled_volume_batch),dim=0)
            labeled_volume_batch, label_batch = labeled_volume_batch.cuda(), label_batch.cuda()
            unlabeled_volume_batch = unlabeled_volume_batch.cuda()

            noise = torch.clamp(torch.randn_like(unlabeled_volume_batch) * 0.05, -0.05, 0.05)
            ema_inputs = unlabeled_volume_batch + noise
            
            outputs = model(labeled_volume_batch)
            unlabeled_outputs = model(unlabeled_volume_batch)
            with torch.no_grad():
                ema_output = ema_model(ema_inputs)
            T = 8
            volume_batch_r = unlabeled_volume_batch.repeat(2, 1, 1, 1, 1)
            stride = volume_batch_r.shape[0] // 2
            preds = torch.zeros([stride * T, num_classes, patch_size[0], patch_size[1], patch_size[2]]).cuda()
            for i in range(T//2):
                TCO = TransformConsistantOperator(k=i, axis=np.random.randint(0, 2))
                ema_inputs = volume_batch_r + torch.clamp(torch.randn_like(volume_batch_r) * 0.1, -0.2, 0.2)
                with torch.no_grad():
                    ema_inputs = TCO.transform(ema_inputs).cuda()
                    pred = ema_model(ema_inputs)
                    pred = TCO.inv_transform(pred).cuda()
                    preds[2 * stride * i:2 * stride * (i + 1)] = pred
            preds = F.softmax(preds, dim=1)
            #preds = preds.reshape(T, stride, 2, 112, 112, 80)
            preds = preds.reshape(T, stride, num_classes, patch_size[0], patch_size[1], patch_size[2])
            preds = torch.mean(preds, dim=0)  #(batch, 2, 112,112,80)
            uncertainty = -1.0*torch.sum(preds*torch.log(preds + 1e-6), dim=1, keepdim=True)

            ## calculate the loss(only for labeled samples)
            loss_seg = F.cross_entropy( outputs, label_batch, weight=torch.tensor(cls_weights,dtype=torch.float32).cuda() )
            outputs_soft = F.softmax(outputs, dim=1)
            loss_seg_dice = 0
            print('\n')
            for i in range(num_classes):
                loss_mid = losses.dice_loss(outputs_soft[:, i, :, :, :], label_batch == i )
                loss_seg_dice += loss_mid
                print('dice score (1-dice_loss): {:.3f}'.format(1-loss_mid))

#             print('dicetotal:{:.3f}'.format( loss_seg_dice))
            #loss_seg_dice = losses.dice_loss(outputs_soft[:labeled_bs, 1, :, :, :], label_batch[:labeled_bs] == 1)
            supervised_loss = 0.5*(loss_seg+loss_seg_dice)
            
            # only for unlabeled samples
            consistency_weight = get_current_consistency_weight(iter_num//150)
            consistency_dist = consistency_criterion(unlabeled_outputs, ema_output) #(batch, num_classes, 112,112,80)
            threshold = (0.75+0.25*ramps.sigmoid_rampup(iter_num, max_iterations))*np.sqrt(3)#N分类问题的最大不确定度是sqrt(N)
            mask = (uncertainty<threshold).float()
#             print("consistency_dist:",consistency_dist.item())
            asd = np.prod( list(mask.shape) )
            #print("mask:",np.sum(mask.item())/asd )
            consistency_dist = torch.sum(mask*consistency_dist)/(2*torch.sum(mask)+1e-16)
            consistency_loss = consistency_weight * consistency_dist
            loss = supervised_loss + consistency_loss
            
            
            # pytorch模型训练的三板斧
            # 一般训练神经网络，总是逃不开optimizer.zero_grad之后是loss（后面有的时候还会写forward，看你网络怎么写了）之后是是net.backward之后是optimizer.step的这个过程
            optimizer.zero_grad()#把模型中参数的梯度设为0
            loss.backward()
            optimizer.step()
            update_ema_variables(model, ema_model, args.ema_decay, iter_num)

            iter_num = iter_num + 1
            writer.add_scalar('uncertainty/mean', uncertainty[0,0].mean(), iter_num)
            writer.add_scalar('uncertainty/max', uncertainty[0,0].max(), iter_num)
            writer.add_scalar('uncertainty/min', uncertainty[0,0].min(), iter_num)
            writer.add_scalar('uncertainty/mask_per', torch.sum(mask)/mask.numel(), iter_num)
            writer.add_scalar('uncertainty/threshold', threshold, iter_num)
            writer.add_scalar('lr', lr_, iter_num)
            writer.add_scalar('loss/loss', loss, iter_num)
            writer.add_scalar('loss/loss_seg', loss_seg, iter_num)
            writer.add_scalar('loss/loss_seg_dice', loss_seg_dice, iter_num)
            writer.add_scalar('train/consistency_loss', consistency_loss, iter_num)
            writer.add_scalar('train/consistency_weight', consistency_weight, iter_num)
            writer.add_scalar('train/consistency_dist', consistency_dist, iter_num)

            logging.info('iteration %d : loss : %f, loss_seg : %f, loss_seg_dice : %f, consistency_loss : %f, cons_dist: %f, loss_weight: %f' %
                         (iter_num, 
                          loss.item(), 
                          loss_seg.item(),
                          loss_seg_dice.item(),
                          consistency_loss.item(),
                          consistency_dist.item(),
                          consistency_weight))
            if iter_num % 50 == 0:
                image = labeled_volume_batch[0, 0:1, :, :, 20:61:10].permute(3, 0, 1, 2).repeat(1, 3, 1, 1)#repeat 3是为了模拟图像的RGB三个通道
                grid_image = make_grid(image, 5, normalize=True)
                writer.add_image('train/Image', grid_image, iter_num)

                # image = outputs_soft[0, 3:4, :, :, 20:61:10].permute(3, 0, 1, 2).repeat(1, 3, 1, 1)
                image = torch.max(outputs_soft[0, :, :, :, 20:61:10], 0)[1].permute(2, 0, 1).data.cpu().numpy()
                image = utils.decode_seg_map_sequence(image)
                grid_image = make_grid(image, 5, normalize=False)
                writer.add_image('train/Predicted_label', grid_image, iter_num)

                image = label_batch[0, :, :, 20:61:10].permute(2, 0, 1)
                grid_image = make_grid(utils.decode_seg_map_sequence(image.data.cpu().numpy()), 5, normalize=False)
                writer.add_image('train/Groundtruth_label', grid_image, iter_num)

                image = uncertainty[0, 0:1, :, :, 20:61:10].permute(3, 0, 1, 2).repeat(1, 3, 1, 1)
                grid_image = make_grid(image, 5, normalize=True)
                writer.add_image('train/uncertainty', grid_image, iter_num)

                mask2 = (uncertainty > threshold).float()
                image = mask2[0, 0:1, :, :, 20:61:10].permute(3, 0, 1, 2).repeat(1, 3, 1, 1)
                grid_image = make_grid(image, 5, normalize=True)
                writer.add_image('train/mask', grid_image, iter_num)
                #####
                image = labeled_volume_batch[-1, 0:1, :, :, 20:61:10].permute(3, 0, 1, 2).repeat(1, 3, 1, 1)
                grid_image = make_grid(image, 5, normalize=True)
                writer.add_image('unlabel/Image', grid_image, iter_num)

                # image = outputs_soft[-1, 3:4, :, :, 20:61:10].permute(3, 0, 1, 2).repeat(1, 3, 1, 1)
                image = torch.max(outputs_soft[-1, :, :, :, 20:61:10], 0)[1].permute(2, 0, 1).data.cpu().numpy()
                image = utils.decode_seg_map_sequence(image)
                grid_image = make_grid(image, 5, normalize=False)
                writer.add_image('unlabel/Predicted_label', grid_image, iter_num)

                image = label_batch[-1, :, :, 20:61:10].permute(2, 0, 1)
                grid_image = make_grid(utils.decode_seg_map_sequence(image.data.cpu().numpy()), 5, normalize=False)
                writer.add_image('unlabel/Groundtruth_label', grid_image, iter_num)

            ## change lr
            if iter_num % 2500 == 0:
                lr_ = base_lr * 0.1 ** (iter_num // 2500)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_
            if iter_num % 1000 == 0:
                save_mode_path = os.path.join(snapshot_path, 'iter_' + str(iter_num) + '.pth')
                torch.save(model.state_dict(), save_mode_path)
                logging.info("save model to {}".format(save_mode_path))

            if iter_num >= max_iterations:
                break
            time1 = time.time()
        if iter_num >= max_iterations:
            break
    save_mode_path = os.path.join(snapshot_path, 'iter_'+str(max_iterations)+'.pth')
    torch.save(model.state_dict(), save_mode_path)
    logging.info("save model to {}".format(save_mode_path))
    writer.close()


Namespace(base_lr=0.01, batch_size=4, consistency=0.1, consistency_rampup=40.0, consistency_type='mse', deterministic=1, ema_decay=0.99, exp='UAMT_unlabel', gpu='0', labeled_bs=2, max_iterations=6000, root_path_labeled='../../data/gz_dataset/segmented', root_path_unlabeled='../../data/gz_dataset/unsegmented/success', seed=1337)
total 32 samples
../../data/gz_dataset/unsegmented/success/../../train_unseg_centercut.list
total 42 samples
21 itertations per epoch


  0%|                                         | 0/376 [00:00<?, ?it/s]

max_epoch: 376


dice score (1-dice_loss): 0.550
dice score (1-dice_loss): 0.202
dice score (1-dice_loss): 0.138
dice score (1-dice_loss): 0.111
iteration 1 : loss : 2.204197, loss_seg : 1.410044, loss_seg_dice : 2.998127, consistency_loss : 0.000111, cons_dist: 0.165281, loss_weight: 0.000674


dice score (1-dice_loss): 0.569
dice score (1-dice_loss): 0.215
dice score (1-dice_loss): 0.160
dice score (1-dice_loss): 0.095
iteration 2 : loss : 2.148988, loss_seg : 1.336917, loss_seg_dice : 2.960974, consistency_loss : 0.000042, cons_dist: 0.062557, loss_weight: 0.000674


dice score (1-dice_loss): 0.486
dice score (1-dice_loss): 0.003
dice score (1-dice_loss): 0.199
dice score (1-dice_loss): 0.005
iteration 3 : loss : 2.342353, loss_seg : 1.377737, loss_seg_dice : 3.306852, consistency_loss : 0.000058, cons_dist: 0.086155, loss_weight: 0.000674


dice score (1-dice_loss): 0.600
dice score (1-dice_loss): 0.186
dice score (1-dice_loss): 0.151
dice score (1-dice_loss): 0.150
iteration 4 : l

  0%|                               | 1/376 [01:08<7:08:39, 68.59s/it]



dice score (1-dice_loss): 0.874
dice score (1-dice_loss): 0.213
dice score (1-dice_loss): 0.466
dice score (1-dice_loss): 0.197
iteration 17 : loss : 1.663672, loss_seg : 1.077228, loss_seg_dice : 2.250080, consistency_loss : 0.000018, cons_dist: 0.026572, loss_weight: 0.000674


dice score (1-dice_loss): 0.858
dice score (1-dice_loss): 0.030
dice score (1-dice_loss): 0.517
dice score (1-dice_loss): 0.101
iteration 18 : loss : 1.584368, loss_seg : 0.675163, loss_seg_dice : 2.493521, consistency_loss : 0.000026, cons_dist: 0.038153, loss_weight: 0.000674


dice score (1-dice_loss): 0.848
dice score (1-dice_loss): 0.148
dice score (1-dice_loss): 0.485
dice score (1-dice_loss): 0.118
iteration 19 : loss : 1.646052, loss_seg : 0.889790, loss_seg_dice : 2.402240, consistency_loss : 0.000037, cons_dist: 0.055449, loss_weight: 0.000674
