In [1]:
import numpy as np 
import h5py
import os 
import random 
import itertools
import math
import sys

import torch 
import torch.nn as nn 
import torch.nn.functional as F 
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import Sampler
from torch.optim import SGD
import torchvision.transforms as transforms
from skimage.measure import label 

import torch.backends.cudnn as cudnn

In [2]:
!pip install medpy



In [3]:
from tensorboardX import SummaryWriter
import shutil
import argparse
import logging
import random
from medpy import metric
import pdb
from tqdm import tqdm

In [4]:
# Params 
class params: 
    def __init__(self):
        self.root_path = 'LA'
        self.exp = 'BCP' 
        self.model = 'VNet'
        self.pre_max_iterations = 20
        self.self_max_iteration = 10 
        self.max_samples = 80 
        self.labeled_bs = 4
        self.batch_size = 8 
        self.base_lr = 0.01 
        self.deterministic = 1 
        self.labelnum = 8 
        self.consistency = 1.0 
        self.consistency_rampup = 40.0 
        self.magnitude = 10.0 
        self.seed = 10
        self.gpu = '0'
    
        # Setting of BCP 
        self.u_weight = 0.5 
        self.mask_ratio = 2/3 

        # Setting of mixup 
        self.u_alpha = 2.0 
        self.loss_weight = 0.5


args = params()

#### 1. BaseDataset

In [5]:
import os
import h5py
from torch.utils.data import Dataset

class LAHeart(Dataset):
    def __init__(self, base_dir, split='train', transform=None, num=None):
        self._base_dir = base_dir
        self.split = split
        self.transform = transform
        self.sample_list = []
        
        # Path for train/test list
        list_file = os.path.join(self._base_dir, f"{split}.list")
        if not os.path.isfile(list_file):
            raise ValueError(f"The {split} list file is missing: {list_file}")
        
        with open(list_file, 'r') as file:
            self.sample_list = [item.strip() for item in file.readlines()]
        
        if num is not None:
            self.sample_list = self.sample_list[:num]

        print(f"Mode = {self.split}, total samples: {len(self.sample_list)}")

    def __len__(self):
        return len(self.sample_list)

    def __getitem__(self, index):
        case = self.sample_list[index]
        file_path = os.path.join(self._base_dir, f'2018LA_Seg_Training Set/{case}/mri_norm2.h5')
        
        # Load data safely
        try:
            with h5py.File(file_path, 'r') as h5f:
                image = h5f['image'][:]
                label = h5f['label'][:]
        except FileNotFoundError:
            raise FileNotFoundError(f"File not found: {file_path}")
        
        sample = {'image': image, 'label': label}
        if self.transform:
            sample = self.transform(sample)
        
        return sample


In [6]:
def random_rot_flip(image, label): 
    k = np.random.randint(0, 4, 1) 
    image = np.rot90(image, k) 
    label = np.rot90(label, k) 

    axis = np.random.randint(0, 2)
    image = np.flip(image, axis) 
    label = np.flip(label, axis) 

    return image, label 

class RandomRotFlip: 
    def __call__(self, sample): 
        image, label = sample['image'], sample['label']
        image, label = random_rot_flip(image, label) 
        sample = {'image': image, 'label': label}

        return sample 

In [7]:
class RandomCrop: 
    def __init__(self, output_size, with_sdf= False): 
        self.output_size = output_size 
        self.with_sdf = with_sdf
    
    def __call__(self, sample): 
        image, label = sample['image'], sample['label']

        if self.with_sdf: 
            sdf = sample['sdf']
        
        if label.shape[0] <= self.output_size[0] or label.shape[1] <= self.output_size[1] or label.shape[2] <= self.output_size[2]: 
            pw = max((self.output_size[0] - label.shape[0]) // 2 + 3, 0) 
            ph = max((self.output_size[1] - label.shape[1]) // 2 + 3, 0)
            pd = max((self.output_size[2] - label.shape[2]) // 2 + 3, 0)
            image = np.pad(image, [(pw, pw), (ph, ph),(pd, pd)], mode= 'constant', constant_values= 0)
            label = np.pad(label, [(pw, pw), (ph, ph), (pd, pd)], mode= 'constant', constant_values= 0) 

            if self.with_sdf: 
                sdf = np.pad(sdf, [(pw, pw), (ph, ph), (pd, pd)], mode= 'constant', constant_values= 0) 

        (w, h,d) = image.shape 
        w1 = np.random.randint(0, w - self.output_size[0])
        h1 = np.random.randint(0, h - self.output_size[1])
        d1 = np.random.randint(0, d - self.output_size[2])
    
        image = image[w1 : w1 + self.output_size[0], h1 : h1 + self.output_size[1], d1 : d1 + self.output_size[2]]
        label = label[w1: w1 + self.output_size[0], h1 : h1 + self.output_size[1], d1 : d1 + self.output_size[2]]

        if self.with_sdf: 
            sdf = sdf[w1 : w1 + self.output_size[0], h1 : h1 + self.output_size[1], d1 : d1 + self.output_size[2]]
            return {'image': image, 'label': label, 'sdf': sdf}
        else: 
            return {'image': image, 'label': label}


In [8]:
class ToTensor(object):
    def __call__(self, sample):
        image = sample['image'].copy()
        label = sample['label'].copy()
        
        # Ensure the image is reshaped correctly and avoid negative strides by copying the array
        image = image.reshape(1, image.shape[0], image.shape[1], image.shape[2]).astype(np.float32)
        
        if 'onehot_label' in sample:
            return {
                'image': torch.from_numpy(image),
                'label': torch.from_numpy(sample['label']).long(),
                'onehot_label': torch.from_numpy(sample['onehot_label']).long()
            }
        else:
            return {
                'image': torch.from_numpy(image),
                'label': torch.from_numpy(label).long()
            }


In [9]:
def iterate_once(indices): 
    """
    Permutate the iterable once 
    (permutate the labeled_idxs once)
    """
    return np.random.permutation(indices) 

def iterate_externally(indices): 
    """
    Create an infinite iterator that repeatedly permutes the indices.
    ( permutate the unlabeled_idxs to make different)
    """
    def infinite_shuffles(): 
        while True: 
            yield np.random.permutation(indices)
            
    return itertools.chain.from_iterable(infinite_shuffles())

def grouper(iterable, n): 
    args = [iter(iterable)] * n 
    return zip(*args)

In [10]:
class TwoStreamBatchSampler(Sampler):
    def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size):
        self.primary_indices = primary_indices
        self.secondary_indices = secondary_indices
        self.primary_batch_size = batch_size - secondary_batch_size
        self.secondary_batch_size = secondary_batch_size
        
        assert len(self.primary_indices) >= self.primary_batch_size > 0
        assert len(self.secondary_indices) >= self.secondary_batch_size > 0
        
    def __iter__(self):
        primary_iter = iterate_once(self.primary_indices)
        secondary_iter = iterate_externally(self.secondary_indices)
        return (
            primary_batch + secondary_batch
            for (primary_batch, secondary_batch)
            in zip(grouper(primary_iter, self.primary_batch_size),
                   grouper(secondary_iter, self.secondary_batch_size))
        )
        
    def __len__(self):
        return len(self.primary_indices) // self.primary_batch_size

#### 2.Loss

In [11]:
def context_mask(img, mask_ratio):
    batch_size, channel, img_x, img_y, img_z = img.shape[0], img.shape[1], img.shape[2], img.shape[3], img.shape[4]
    loss_mask = torch.ones(batch_size, img_x, img_y, img_z).cuda()
    mask = torch.ones(img_x, img_y, img_z).cuda()
    patch_pixel_x, patch_pixel_y, patch_pixel_z = int(img_x*mask_ratio), int(img_y*mask_ratio), int(img_z*mask_ratio)
    w = np.random.randint(0, 112 - patch_pixel_x)
    h = np.random.randint(0, 112 - patch_pixel_y)
    z = np.random.randint(0, 80 - patch_pixel_z)
    mask[w:w+patch_pixel_x, h:h+patch_pixel_y, z:z+patch_pixel_z] = 0
    loss_mask[:, w:w+patch_pixel_x, h:h+patch_pixel_y, z:z+patch_pixel_z] = 0
    return mask.long(), loss_mask.long()

In [12]:
def to_one_hot(tensor, nclasses): 
    """
    Input (tensor): Nx1xHxW 
    """
    assert tensor.max().item() < nclasses
    assert tensor.min().item() >= 0 

    size = list(tensor.size())
    assert size[1] == 1 
    size[1] = nclasses
    one_hot = torch.zeros(*size) 
    if tensor.is_cuda: 
        one_hot = one_hot.cuda(tensor.device) 
    one_hot = one_hot.scatter_(1, tensor, 1) 
    return one_hot 

def get_probability(logits): 
    """
    Get the probability from logitis  
    """
    size = logits.size() 
    if size[1] > 1: 
        pred = F.softmax(logits, dim= 1) 
        nclass = size[1] 
    else: 
        pred = F.sigmoid(logits) 
        pred = torch.cat([1 - pred, pred], dim= 1) 
    
    return pred, nclass


class mask_DiceLoss(nn.Module): 
    def __init__(self, nclass, class_weights = None, smooth= 1e-5): 
        super(mask_DiceLoss, self).__init__() 
        self.smooth = smooth 
        if class_weights is None: 
            self.class_weights = nn.Parameter(torch.ones((1, nclass)).type(torch.float32), requires_grad= False) 
        else: 
            class_weights = np.array(class_weights) 
            assert nclass == class_weights.shape[0] 
            self.class_weights = nn.Parameter(torch.tensor(class_weights, dtype= torch.float32), requires_grad= False) 
    
    def prob_forward(self, pred, target, mask= None): 
        size = pred.size() 
        N, nclass = size[0], size[1] 

        # N x C x H x W: convert into 2D image
        pred_one_hot = pred.view(N, nclass, -1) 
        target = target.view(N, 1, -1) 
        target_one_hot = to_one_hot(target.type(torch.long), nclass).type(torch.float32)

        # N x C x H x W
        inter = pred_one_hot * target_one_hot
        union = pred_one_hot + target_one_hot

        if mask is not None: 
            mask = mask.view(N, 1, -1) 
            inter = (inter.view(N, nclass, -1) * mask).sum(2) 
            union = (union.view(N, nclass, -1) * mask).sum(2) 
        else: 
            inter = inter.view(N, nclass, -1).sum(2) 
            union = union.view(N, nclass, -1).sum(2)
        
        dice = ( 2*inter + self.smooth ) / (union + self.smooth) 
        return 1 - dice.mean()

    def forward(self, logits,target, mask = None): 
        size = logits.size() 
        N, nclass = size[0], size[1] 

        logits = logits.view(N, nclass, -1) 
        target = target.view(N, 1, -1) 

        pred, nclass = get_probability(logits) 

        pred_one_hot = pred 
        target_one_hot = to_one_hot(target.type(torch.long), nclass).type(torch.float32) 

        inter = pred_one_hot * target_one_hot
        union = pred_one_hot + target_one_hot

        if mask is not None: 
            mask = mask.view(N, 1, -1) 
            inter = (inter.view(N, nclass, -1) * mask).sum(2)
            union = (union.view(N, nclass, -1) * mask ).sum(2) 
        else: 
            inter = inter.view(N, nclass, -1).sum(2) 
            union = union.view(N, nclass, -1).sum(2)
        
        dice = ( 2 * inter + self.smooth ) / (union + self.smooth)
        return 1 - dice.mean() 
        


In [13]:
DICE  = mask_DiceLoss(nclass= 2)
CE = nn.CrossEntropyLoss()
def mix_loss(net3_output, img_l, patch_l, mask, l_weight=1.0, u_weight=0.5, unlab=False):
    img_l, patch_l = img_l.type(torch.int64), patch_l.type(torch.int64)
    image_weight, patch_weight = l_weight, u_weight
    if unlab:
        image_weight, patch_weight = u_weight, l_weight
    patch_mask = 1 - mask
    dice_loss = DICE(net3_output, img_l, mask) * image_weight 
    dice_loss += DICE(net3_output, patch_l, patch_mask) * patch_weight
    loss_ce = image_weight * (CE(net3_output, img_l) * mask).sum() / (mask.sum() + 1e-16) 
    loss_ce += patch_weight * (CE(net3_output, patch_l) * patch_mask).sum() / (patch_mask.sum() + 1e-16)
    loss = (dice_loss + loss_ce) / 2
    return loss

#### 3. Validaion 

In [14]:
def test_single_case(model, image, stride_xy, stride_z, patch_size, num_classes=1):
    w, h, d = image.shape

    # if the size of image is less than patch_size, then padding it
    add_pad = False
    if w < patch_size[0]:
        w_pad = patch_size[0]-w
        add_pad = True
    else:
        w_pad = 0
    if h < patch_size[1]:
        h_pad = patch_size[1]-h
        add_pad = True
    else:
        h_pad = 0
    if d < patch_size[2]:
        d_pad = patch_size[2]-d
        add_pad = True
    else:
        d_pad = 0
    wl_pad, wr_pad = w_pad//2,w_pad-w_pad//2
    hl_pad, hr_pad = h_pad//2,h_pad-h_pad//2
    dl_pad, dr_pad = d_pad//2,d_pad-d_pad//2
    if add_pad:
        image = np.pad(image, [(wl_pad,wr_pad),(hl_pad,hr_pad), (dl_pad, dr_pad)], mode='constant', constant_values=0)
    ww,hh,dd = image.shape

    sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1
    sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1
    sz = math.ceil((dd - patch_size[2]) / stride_z) + 1
    # print("{}, {}, {}".format(sx, sy, sz))
    score_map = np.zeros((num_classes, ) + image.shape).astype(np.float32)
    cnt = np.zeros(image.shape).astype(np.float32)

    for x in range(0, sx):
        xs = min(stride_xy*x, ww-patch_size[0])
        for y in range(0, sy):
            ys = min(stride_xy * y,hh-patch_size[1])
            for z in range(0, sz):
                zs = min(stride_z * z, dd-patch_size[2])
                test_patch = image[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]]
                test_patch = np.expand_dims(np.expand_dims(test_patch,axis=0),axis=0).astype(np.float32)
                test_patch = torch.from_numpy(test_patch).cuda()

                with torch.no_grad():
                    y1, _ = model(test_patch)
                    y = F.softmax(y1, dim=1)

                y = y.cpu().data.numpy()
                y = y[0,1,:,:,:]
                score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \
                  = score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + y
                cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \
                  = cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + 1
    score_map = score_map/np.expand_dims(cnt,axis=0)
    label_map = (score_map[0]>0.5).astype(np.int32)
    if add_pad:
        label_map = label_map[wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d]
        score_map = score_map[:,wl_pad:wl_pad+w,hl_pad:hl_pad+h,dl_pad:dl_pad+d]
    return label_map, score_map


In [15]:
def var_all_case_LA(model, num_classes, patch_size=(112, 112, 80), stride_xy=18, stride_z=4):
    file_path = 'LA/test.list'
    with open(file_path, 'r') as f:
        image_list = f.readlines()
    
    folder_dir = 'LA/2018LA_Seg_Training Set/'
    image_list = [folder_dir + item.replace('\n', '') + "/mri_norm2.h5" for item in image_list]
    loader = tqdm(image_list)
    total_dice = 0.0
    for image_path in loader:
        h5f = h5py.File(image_path, 'r')
        image = h5f['image'][:]
        label = h5f['label'][:]
        prediction, score_map = test_single_case(model, image, stride_xy, stride_z, patch_size, num_classes=num_classes)
        if np.sum(prediction)==0:
            dice = 0
        else:
            dice = metric.binary.dc(prediction, label)
        total_dice += dice
    avg_dice = total_dice / len(image_list)
    print('average metric is {}'.format(avg_dice))
    return avg_dice

#### 4.VNet

In [16]:
import torch
from torch import nn

class ConvBlock(nn.Module):
    def __init__(self, n_stages, n_filters_in, n_filters_out, kernel_size=3, padding=1, normalization='none'):
        super(ConvBlock, self).__init__()

        ops = []
        for i in range(n_stages):
            if i==0:
                input_channel = n_filters_in
            else:
                input_channel = n_filters_out

            ops.append(nn.Conv3d(input_channel, n_filters_out, kernel_size=kernel_size, padding=padding))
            if normalization == 'batchnorm':
                ops.append(nn.BatchNorm3d(n_filters_out))
            elif normalization == 'groupnorm':
                ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
            elif normalization == 'instancenorm':
                ops.append(nn.InstanceNorm3d(n_filters_out))
            elif normalization != 'none':
                assert False
            ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        x = self.conv(x)
        return x


class ResidualConvBlock(nn.Module):
    def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='none'):
        super(ResidualConvBlock, self).__init__()

        ops = []
        for i in range(n_stages):
            if i == 0:
                input_channel = n_filters_in
            else:
                input_channel = n_filters_out

            ops.append(nn.Conv3d(input_channel, n_filters_out, 3, padding=1))
            if normalization == 'batchnorm':
                ops.append(nn.BatchNorm3d(n_filters_out))
            elif normalization == 'groupnorm':
                ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
            elif normalization == 'instancenorm':
                ops.append(nn.InstanceNorm3d(n_filters_out))
            elif normalization != 'none':
                assert False

            if i != n_stages-1:
                ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = (self.conv(x) + x)
        x = self.relu(x)
        return x


class DownsamplingConvBlock(nn.Module):
    def __init__(self, n_filters_in, n_filters_out, stride=2, padding=0, normalization='none'):
        super(DownsamplingConvBlock, self).__init__()

        ops = []
        if normalization != 'none':
            ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=padding, stride=stride))
            if normalization == 'batchnorm':
                ops.append(nn.BatchNorm3d(n_filters_out))
            elif normalization == 'groupnorm':
                ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
            elif normalization == 'instancenorm':
                ops.append(nn.InstanceNorm3d(n_filters_out))
            else:
                assert False
        else:
            ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=padding, stride=stride))

        ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        x = self.conv(x)
        return x


class UpsamplingDeconvBlock(nn.Module):
    def __init__(self, n_filters_in, n_filters_out, stride=2, padding=0,normalization='none'):
        super(UpsamplingDeconvBlock, self).__init__()

        ops = []
        if normalization != 'none':
            ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=padding, stride=stride))
            if normalization == 'batchnorm':
                ops.append(nn.BatchNorm3d(n_filters_out))
            elif normalization == 'groupnorm':
                ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
            elif normalization == 'instancenorm':
                ops.append(nn.InstanceNorm3d(n_filters_out))
            else:
                assert False
        else:
            ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=padding, stride=stride))

        ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        x = self.conv(x)
        return x
    

class Upsampling(nn.Module):
    def __init__(self, n_filters_in, n_filters_out, stride=2, normalization='none'):
        super(Upsampling, self).__init__()

        ops = []
        ops.append(nn.Upsample(scale_factor=stride, mode="trilinear",align_corners=False))
        ops.append(nn.Conv3d(n_filters_in, n_filters_out, kernel_size=3, padding=1))
        if normalization == 'batchnorm':
            ops.append(nn.BatchNorm3d(n_filters_out))
        elif normalization == 'groupnorm':
            ops.append(nn.GroupNorm(num_groups=16, num_channels=n_filters_out))
        elif normalization == 'instancenorm':
            ops.append(nn.InstanceNorm3d(n_filters_out))
        elif normalization != 'none':
            assert False
        ops.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*ops)

    def forward(self, x):
        x = self.conv(x)
        return x
    
class Encoder(nn.Module):
    def __init__(self, n_channels=3, n_classes=2, n_filters=16, normalization='none', has_dropout=False, has_residual=False):
        super(Encoder, self).__init__()
        self.has_dropout = has_dropout
        convBlock = ConvBlock if not has_residual else ResidualConvBlock

        self.block_one = convBlock(1, n_channels, n_filters, normalization=normalization)
        self.block_one_dw = DownsamplingConvBlock(n_filters, 2 * n_filters, normalization=normalization)

        self.block_two = convBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_two_dw = DownsamplingConvBlock(n_filters * 2, n_filters * 4, normalization=normalization)

        self.block_three = convBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_three_dw = DownsamplingConvBlock(n_filters * 4, n_filters * 8, normalization=normalization)

        self.block_four = convBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_four_dw = DownsamplingConvBlock(n_filters * 8, n_filters * 16, normalization=normalization)

        self.block_five = convBlock(3, n_filters * 16, n_filters * 16, normalization=normalization)
        
        self.dropout = nn.Dropout3d(p=0.5, inplace=False)

    def forward(self, input):
        x1 = self.block_one(input)
        x1_dw = self.block_one_dw(x1)

        x2 = self.block_two(x1_dw)
        x2_dw = self.block_two_dw(x2)

        x3 = self.block_three(x2_dw)
        x3_dw = self.block_three_dw(x3)

        x4 = self.block_four(x3_dw)
        x4_dw = self.block_four_dw(x4)

        x5 = self.block_five(x4_dw)

        if self.has_dropout:
            x5 = self.dropout(x5)

        res = [x1, x2, x3, x4, x5]
        return res


class Decoder(nn.Module):
    def __init__(self, n_channels=3, n_classes=2, n_filters=16, normalization='none', has_dropout=False, has_residual=False):
        super(Decoder, self).__init__()
        self.has_dropout = has_dropout

        convBlock = ConvBlock if not has_residual else ResidualConvBlock

        upsampling = UpsamplingDeconvBlock ## using transposed convolution

        self.block_five_up = upsampling(n_filters * 16, n_filters * 8, normalization=normalization)

        self.block_six = convBlock(3, n_filters * 8, n_filters * 8, normalization=normalization)
        self.block_six_up = upsampling(n_filters * 8, n_filters * 4, normalization=normalization)

        self.block_seven = convBlock(3, n_filters * 4, n_filters * 4, normalization=normalization)
        self.block_seven_up = upsampling(n_filters * 4, n_filters * 2, normalization=normalization)

        self.block_eight = convBlock(2, n_filters * 2, n_filters * 2, normalization=normalization)
        self.block_eight_up = upsampling(n_filters * 2, n_filters, normalization=normalization)

        self.block_nine = convBlock(1, n_filters, n_filters, normalization=normalization)
        self.out_conv = nn.Conv3d(n_filters, n_classes, 1, padding=0)
        self.dropout = nn.Dropout3d(p=0.5, inplace=False)

    def forward(self, features):
        x1 = features[0]
        x2 = features[1]
        x3 = features[2]
        x4 = features[3]
        x5 = features[4]
        
        x5_up = self.block_five_up(x5)
        x5_up = x5_up + x4

        x6 = self.block_six(x5_up)
        x6_up = self.block_six_up(x6)
        x6_up = x6_up + x3

        x7 = self.block_seven(x6_up)
        x7_up = self.block_seven_up(x7)
        x7_up = x7_up + x2

        x8 = self.block_eight(x7_up)
        x8_up = self.block_eight_up(x8)
        x8_up = x8_up + x1
        x9 = self.block_nine(x8_up)
        # x9 = F.dropout3d(x9, p=0.5, training=True)
        if self.has_dropout:
            x9 = self.dropout(x9)
        out_seg = self.out_conv(x9)
        return out_seg, x8_up
 
class VNet(nn.Module):
    def __init__(self, n_channels=3, n_classes=2, n_filters=16, normalization='none', has_dropout=False, has_residual=False):
        super(VNet, self).__init__()

        self.encoder = Encoder(n_channels, n_classes, n_filters, normalization, has_dropout, has_residual)
        self.decoder = Decoder(n_channels, n_classes, n_filters, normalization, has_dropout, has_residual)
        dim_in = 16
        feat_dim = 32
        self.pool = nn.MaxPool3d(3, stride=2)
        self.projection_head = nn.Sequential(
            nn.Linear(dim_in, feat_dim),
            nn.BatchNorm1d(feat_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feat_dim, feat_dim)
        )
        self.prediction_head = nn.Sequential(
            nn.Linear(feat_dim, feat_dim),
            nn.BatchNorm1d(feat_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feat_dim, feat_dim)
        )
        for class_c in range(2):
            selector = nn.Sequential(
                nn.Linear(feat_dim, feat_dim),
                nn.BatchNorm1d(feat_dim),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Linear(feat_dim, 1)
            )
            self.__setattr__('contrastive_class_selector_' + str(class_c), selector)

        for class_c in range(2):
            selector = nn.Sequential(
                nn.Linear(feat_dim, feat_dim),
                nn.BatchNorm1d(feat_dim),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Linear(feat_dim, 1)
            )
            self.__setattr__('contrastive_class_selector_memory' + str(class_c), selector)
        
    def forward_projection_head(self, features):
        return self.projection_head(features)

    def forward_prediction_head(self, features):
        return self.prediction_head(features)

    def forward(self, input):
        features = self.encoder(input)
        out_seg, x8_up = self.decoder(features)
        features = self.pool(features[4])
        return out_seg, features # 4, 16, 112, 112, 80

#### 5. Training process

In [17]:
def sigmoid_rampup(current, rampup_length):
    if rampup_length == 0: 
        return 1.0 
    else:
        current = np.clip(current, 0, rampup_length)
        phase = 1 - (current / rampup_length)
        return float(np.exp(-5 * phase * phase))
    
# Mean-Teacher compomnent 
def get_current_consistency_weight(epoch, args): 
    return 5 * args.consistency + sigmoid_rampup(epoch, args.consistency_rampup)

@torch.no_grad()
def update_ema_variables(model, ema_model, alpha):
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.data.mul_(alpha).add_((1 - alpha) * param.data)


In [18]:
def get_cut_mask(out, thres=0.5, nms=0):
    probs = F.softmax(out, 1)
    masks = (probs >= thres).type(torch.int64)
    masks = masks[:, 1, :, :].contiguous()
    if nms == 1:
        masks = LargestCC_pancreas(masks)
    return masks

def LargestCC_pancreas(segmentation):
    N = segmentation.shape[0]
    batch_list = []
    for n in range(N):
        n_prob = segmentation[n].detach().cpu().numpy()
        labels = label(n_prob)
        if labels.max() != 0:
            largestCC = labels == np.argmax(np.bincount(labels.flat)[1:])+1
        else:
            largestCC = n_prob
        batch_list.append(largestCC)
    
    return torch.Tensor(batch_list).cuda()

def save_net_opt(net, optimizer, path):
    state = {
        'net': net.state_dict(),
        'opt': optimizer.state_dict(),
    }
    torch.save(state, str(path))

def load_net_opt(net, optimizer, path):
    state = torch.load(str(path))
    net.load_state_dict(state['net'])
    optimizer.load_state_dict(state['opt'])

def load_net(net, path):
    state = torch.load(str(path))
    net.load_state_dict(state['net'])

def get_current_consistency_weight(epoch):
    return args.consistency * sigmoid_rampup(epoch, args.consistency_rampup)

In [19]:
def net_factory(net_type="Unet", in_channels=1, class_num=2, mode="train", tsne=0):
    if net_type == "VNet" and mode == "train" and tsne==0:
        net = VNet(n_channels=in_channels, n_classes=class_num, normalization='batchnorm', has_dropout=True).cuda()
        
    return net

In [20]:
train_data_path = args.root_path 
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 
pre_max_iterations = args.pre_max_iterations
self_max_iterations = args.self_max_iteration 
base_lr = args.base_lr 
CE = nn.CrossEntropyLoss(reduction= 'none')

if args.deterministic:
    cudnn.benchmark = False
    cudnn.deterministic = True
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

patch_size = (112, 112, 80)
num_classes = 2

In [21]:
def pre_train(args, snapshot_path):
    model = net_factory(args.model, in_channels=1, class_num=num_classes, mode="train")
    db_train = LAHeart(base_dir=train_data_path, 
                       split='train',
                       transform=transforms.Compose([
                           RandomRotFlip(),
                           RandomCrop(patch_size),
                           ToTensor(),
                       ]))
    labelnum = args.labelnum
    labeled_idxs = list(range(labelnum))
    unlabeled_idxs = list(range(labelnum, args.max_samples))
    batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, args.batch_size, args.batch_size-args.labeled_bs)
    sub_bs = int(args.labeled_bs/2)
    def worker_init_fn(worker_id):
        random.seed(args.seed+worker_id)
    trainloader = DataLoader(db_train, batch_sampler=batch_sampler, num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn)
    optimizer = SGD(model.parameters(), lr=base_lr, momentum=0.9, weight_decay=1e-4)
    DICE = mask_DiceLoss(nclass=2)
    
    model.train()
    writer = SummaryWriter(snapshot_path + '/log')
    logging.info("{} iterations per epoch".format(len(trainloader)))
    iter_num = 0
    best_dice = 0
    max_epoch = pre_max_iterations // len(trainloader) + 1
    iterator = tqdm(range(max_epoch), ncols=70)
    for epoch_num in iterator:
        for _, sampled_batch in enumerate(trainloader):
            volume_batch, label_batch = sampled_batch['image'][:args.labeled_bs], sampled_batch['label'][:args.labeled_bs]
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
            img_a, img_b = volume_batch[:sub_bs], volume_batch[sub_bs:]
            lab_a, lab_b = label_batch[:sub_bs], label_batch[sub_bs:]
            
            with torch.no_grad():
                img_mask, loss_mask = context_mask(img_a, args.mask_ratio)
                
            # Mix Input
            volume_batch = img_a * img_mask + img_b * (1 - img_mask)
            label_batch = lab_a * img_mask + lab_b * (1 - img_mask)
            
            outputs, _ = model(volume_batch)
            loss_ce = F.cross_entropy(outputs, label_batch)
            loss_dice = DICE(outputs, label_batch)
            loss = (loss_ce + loss_dice) / 2
            
            iter_num += 1
            writer.add_scalar('pre/loss_dice', loss_dice, iter_num)
            writer.add_scalar('pre/loss_ce', loss_ce, iter_num)
            writer.add_scalar('pre/loss_all', loss, iter_num)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            logging.info('iteration %d : loss: %03f, loss_dice: %03f, loss_ce: %03f' %(iter_num, loss, loss_dice, loss_ce))

            if iter_num % 20 == 0:
                model.eval()
                dice_sample = var_all_case_LA(model, num_classes=num_classes, patch_size=patch_size, stride_xy=18, stride_z=4)
                if dice_sample > best_dice:
                    best_dice = round(dice_sample, 4)
                    save_mode_path = os.path.join(snapshot_path,  'iter_{}_dice_{}.pth'.format(iter_num, best_dice))
                    save_best_path = os.path.join(snapshot_path,'{}_best_model.pth'.format(args.model))
                    save_net_opt(model, optimizer, save_mode_path)
                    save_net_opt(model, optimizer, save_best_path)
                    # torch.save(model.state_dict(), save_mode_path)
                    # torch.save(model.state_dict(), save_best_path)
                    logging.info("save best model to {}".format(save_mode_path))
                writer.add_scalar('4_Var_dice/Dice', dice_sample, iter_num)
                writer.add_scalar('4_Var_dice/Best_dice', best_dice, iter_num)
                model.train()

            
            if iter_num >= pre_max_iterations:
                break

        if iter_num >= pre_max_iterations:
            iterator.close()
            break
    writer.close()

In [29]:
def self_train(args, pre_snapshot_path, self_snapshot_path):
    model = net_factory(net_type=args.model, in_channels=1, class_num=num_classes, mode="train")
    ema_model = net_factory(net_type=args.model, in_channels=1, class_num=num_classes, mode="train")
    for param in ema_model.parameters():
            param.detach_()   # ema_model set
    db_train = LAHeart(base_dir=train_data_path,
                       split='train',
                       transform = transforms.Compose([
                          RandomRotFlip(),
                          RandomCrop(patch_size),
                          ToTensor(),
                          ]))
    labelnum = args.labelnum
    labeled_idxs = list(range(labelnum))
    unlabeled_idxs = list(range(labelnum, args.max_samples))
    batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, args.batch_size, args.batch_size-args.labeled_bs)
    sub_bs = int(args.labeled_bs/2)
    def worker_init_fn(worker_id):
        random.seed(args.seed+worker_id)
    trainloader = DataLoader(db_train, batch_sampler=batch_sampler, num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn)
    optimizer = SGD(model.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001)

    pretrained_model = os.path.join(pre_snapshot_path, f'{args.model}_best_model.pth')
    load_net(model, pretrained_model)
    load_net(ema_model, pretrained_model)
    
    model.train()
    ema_model.train()
    writer = SummaryWriter(self_snapshot_path+'/log')
    logging.info("{} iterations per epoch".format(len(trainloader)))
    iter_num = 0
    best_dice = 0
    max_epoch = self_max_iterations // len(trainloader) + 1
    lr_ = base_lr
    iterator = tqdm(range(max_epoch), ncols=70)
    for epoch in iterator:
        for _, sampled_batch in enumerate(trainloader):
            volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
            img_a, img_b = volume_batch[:sub_bs], volume_batch[sub_bs:args.labeled_bs]
            lab_a, lab_b = label_batch[:sub_bs], label_batch[sub_bs:args.labeled_bs]
            unimg_a, unimg_b = volume_batch[args.labeled_bs:args.labeled_bs+sub_bs], volume_batch[args.labeled_bs+sub_bs:]
            with torch.no_grad():
                unoutput_a, _ = ema_model(unimg_a)
                unoutput_b, _ = ema_model(unimg_b)
                plab_a = get_cut_mask(unoutput_a, nms=1)
                plab_b = get_cut_mask(unoutput_b, nms=1)
                img_mask, loss_mask = context_mask(img_a, args.mask_ratio)
            consistency_weight = get_current_consistency_weight(iter_num // 150)

            mixl_img = img_a * img_mask + unimg_a * (1 - img_mask)
            mixu_img = unimg_b * img_mask + img_b * (1 - img_mask)
            mixl_lab = lab_a * img_mask + plab_a * (1 - img_mask)
            mixu_lab = plab_b * img_mask + lab_b * (1 - img_mask)
            outputs_l, _ = model(mixl_img)
            outputs_u, _ = model(mixu_img)
            loss_l = mix_loss(outputs_l, lab_a, plab_a, loss_mask, u_weight=args.u_weight)
            loss_u = mix_loss(outputs_u, plab_b, lab_b, loss_mask, u_weight=args.u_weight, unlab=True)

            loss = loss_l + loss_u

            iter_num += 1
            writer.add_scalar('Self/consistency', consistency_weight, iter_num)
            writer.add_scalar('Self/loss_l', loss_l, iter_num)
            writer.add_scalar('Self/loss_u', loss_u, iter_num)
            writer.add_scalar('Self/loss_all', loss, iter_num)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            logging.info('iteration %d : loss: %03f, loss_l: %03f, loss_u: %03f'%(iter_num, loss, loss_l, loss_u))

            update_ema_variables(model, ema_model, 0.99)

             # change lr
            if iter_num % 2500 == 0:
                lr_ = base_lr * 0.1 ** (iter_num // 2500)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_

            if iter_num % 200 == 0:
                model.eval()
                dice_sample = var_all_case_LA(model, num_classes=num_classes, patch_size=patch_size, stride_xy=18, stride_z=4)
                if dice_sample > best_dice:
                    best_dice = round(dice_sample, 4)
                    save_mode_path = os.path.join(self_snapshot_path,  'iter_{}_dice_{}.pth'.format(iter_num, best_dice))
                    save_best_path = os.path.join(self_snapshot_path,'{}_best_model.pth'.format(args.model))
                    # save_net_opt(model, optimizer, save_mode_path)
                    # save_net_opt(model, optimizer, save_best_path)
                    torch.save(model.state_dict(), save_mode_path)
                    torch.save(model.state_dict(), save_best_path)
                    logging.info("save best model to {}".format(save_mode_path))
                writer.add_scalar('4_Var_dice/Dice', dice_sample, iter_num)
                writer.add_scalar('4_Var_dice/Best_dice', best_dice, iter_num)
                model.train()
            
            if iter_num % 200 == 1:
                ins_width = 2
                B,C,H,W,D = outputs_l.size()
                snapshot_img = torch.zeros(size = (D, 3, 3*H + 3 * ins_width, W + ins_width), dtype = torch.float32)

                snapshot_img[:,:, H:H+ ins_width,:] = 1
                snapshot_img[:,:, 2*H + ins_width:2*H + 2*ins_width,:] = 1
                snapshot_img[:,:, 3*H + 2*ins_width:3*H + 3*ins_width,:] = 1
                snapshot_img[:,:, :,W:W+ins_width] = 1

                outputs_l_soft = F.softmax(outputs_l, dim=1)
                seg_out = outputs_l_soft[0,1,...].permute(2,0,1) # y
                target =  mixl_lab[0,...].permute(2,0,1)
                train_img = mixl_img[0,0,...].permute(2,0,1)

                snapshot_img[:, 0,:H,:W] = (train_img-torch.min(train_img))/(torch.max(train_img)-torch.min(train_img))
                snapshot_img[:, 1,:H,:W] = (train_img-torch.min(train_img))/(torch.max(train_img)-torch.min(train_img))
                snapshot_img[:, 2,:H,:W] = (train_img-torch.min(train_img))/(torch.max(train_img)-torch.min(train_img))

                snapshot_img[:, 0, H+ ins_width:2*H+ ins_width,:W] = target
                snapshot_img[:, 1, H+ ins_width:2*H+ ins_width,:W] = target
                snapshot_img[:, 2, H+ ins_width:2*H+ ins_width,:W] = target

                snapshot_img[:, 0, 2*H+ 2*ins_width:3*H+ 2*ins_width,:W] = seg_out
                snapshot_img[:, 1, 2*H+ 2*ins_width:3*H+ 2*ins_width,:W] = seg_out
                snapshot_img[:, 2, 2*H+ 2*ins_width:3*H+ 2*ins_width,:W] = seg_out
                
                writer.add_images('Epoch_%d_Iter_%d_labeled'% (epoch, iter_num), snapshot_img)

                outputs_u_soft = F.softmax(outputs_u, dim=1)
                seg_out = outputs_u_soft[0,1,...].permute(2,0,1) # y
                target =  mixu_lab[0,...].permute(2,0,1)
                train_img = mixu_img[0,0,...].permute(2,0,1)

                snapshot_img[:, 0,:H,:W] = (train_img-torch.min(train_img))/(torch.max(train_img)-torch.min(train_img))
                snapshot_img[:, 1,:H,:W] = (train_img-torch.min(train_img))/(torch.max(train_img)-torch.min(train_img))
                snapshot_img[:, 2,:H,:W] = (train_img-torch.min(train_img))/(torch.max(train_img)-torch.min(train_img))

                snapshot_img[:, 0, H+ ins_width:2*H+ ins_width,:W] = target
                snapshot_img[:, 1, H+ ins_width:2*H+ ins_width,:W] = target
                snapshot_img[:, 2, H+ ins_width:2*H+ ins_width,:W] = target

                snapshot_img[:, 0, 2*H+ 2*ins_width:3*H+ 2*ins_width,:W] = seg_out
                snapshot_img[:, 1, 2*H+ 2*ins_width:3*H+ 2*ins_width,:W] = seg_out
                snapshot_img[:, 2, 2*H+ 2*ins_width:3*H+ 2*ins_width,:W] = seg_out

                writer.add_images('Epoch_%d_Iter_%d_unlabel'% (epoch, iter_num), snapshot_img)

            if iter_num >= self_max_iterations:
                break

        if iter_num >= self_max_iterations:
            iterator.close()
            break
    writer.close()

In [26]:
if args.deterministic:
    cudnn.benchmark = False
    cudnn.deterministic = True
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

patch_size = (112, 112, 80)
num_classes = 2
## make logger file
pre_snapshot_path = "./model/BCP/LA_{}_{}_labeled/pre_train".format(args.exp, args.labelnum)
self_snapshot_path = "./model/BCP/LA_{}_{}_labeled/self_train".format(args.exp, args.labelnum)
print("Starting BCP training.")
for snapshot_path in [pre_snapshot_path, self_snapshot_path]:
    if not os.path.exists(snapshot_path):
        os.makedirs(snapshot_path)
# -- Pre-Training
logging.basicConfig(filename=pre_snapshot_path+"/log.txt", level=logging.INFO, format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
logging.info(str(args))
pre_train(args, pre_snapshot_path)

Starting BCP training.
<__main__.params object at 0x7fa3f3728fb0>
<__main__.params object at 0x7fa3f3728fb0>
<__main__.params object at 0x7fa3f3728fb0>
Mode = train, total samples: 80
2 iterations per epoch
2 iterations per epoch
2 iterations per epoch


  0%|                                          | 0/11 [00:00<?, ?it/s]

iteration 1 : loss: 0.680328, loss_dice: 0.582469, loss_ce: 0.778187
iteration 1 : loss: 0.680328, loss_dice: 0.582469, loss_ce: 0.778187
iteration 1 : loss: 0.680328, loss_dice: 0.582469, loss_ce: 0.778187
iteration 2 : loss: 0.722332, loss_dice: 0.591725, loss_ce: 0.852940
iteration 2 : loss: 0.722332, loss_dice: 0.591725, loss_ce: 0.852940
iteration 2 : loss: 0.722332, loss_dice: 0.591725, loss_ce: 0.852940


  9%|███                               | 1/11 [00:00<00:09,  1.00it/s]

iteration 3 : loss: 0.767395, loss_dice: 0.611944, loss_ce: 0.922847
iteration 3 : loss: 0.767395, loss_dice: 0.611944, loss_ce: 0.922847
iteration 3 : loss: 0.767395, loss_dice: 0.611944, loss_ce: 0.922847
iteration 4 : loss: 0.709416, loss_dice: 0.590961, loss_ce: 0.827872
iteration 4 : loss: 0.709416, loss_dice: 0.590961, loss_ce: 0.827872
iteration 4 : loss: 0.709416, loss_dice: 0.590961, loss_ce: 0.827872


 18%|██████▏                           | 2/11 [00:01<00:08,  1.00it/s]

iteration 5 : loss: 0.656469, loss_dice: 0.576009, loss_ce: 0.736929
iteration 5 : loss: 0.656469, loss_dice: 0.576009, loss_ce: 0.736929
iteration 5 : loss: 0.656469, loss_dice: 0.576009, loss_ce: 0.736929
iteration 6 : loss: 0.583871, loss_dice: 0.547771, loss_ce: 0.619971
iteration 6 : loss: 0.583871, loss_dice: 0.547771, loss_ce: 0.619971
iteration 6 : loss: 0.583871, loss_dice: 0.547771, loss_ce: 0.619971


 27%|█████████▎                        | 3/11 [00:03<00:08,  1.01s/it]

iteration 7 : loss: 0.600334, loss_dice: 0.544235, loss_ce: 0.656432
iteration 7 : loss: 0.600334, loss_dice: 0.544235, loss_ce: 0.656432
iteration 7 : loss: 0.600334, loss_dice: 0.544235, loss_ce: 0.656432
iteration 8 : loss: 0.600714, loss_dice: 0.548495, loss_ce: 0.652932
iteration 8 : loss: 0.600714, loss_dice: 0.548495, loss_ce: 0.652932
iteration 8 : loss: 0.600714, loss_dice: 0.548495, loss_ce: 0.652932


 36%|████████████▎                     | 4/11 [00:04<00:07,  1.00s/it]

iteration 9 : loss: 0.517795, loss_dice: 0.526817, loss_ce: 0.508774
iteration 9 : loss: 0.517795, loss_dice: 0.526817, loss_ce: 0.508774
iteration 9 : loss: 0.517795, loss_dice: 0.526817, loss_ce: 0.508774
iteration 10 : loss: 0.526468, loss_dice: 0.489687, loss_ce: 0.563248
iteration 10 : loss: 0.526468, loss_dice: 0.489687, loss_ce: 0.563248
iteration 10 : loss: 0.526468, loss_dice: 0.489687, loss_ce: 0.563248


 45%|███████████████▍                  | 5/11 [00:04<00:05,  1.02it/s]

iteration 11 : loss: 0.522497, loss_dice: 0.504972, loss_ce: 0.540022
iteration 11 : loss: 0.522497, loss_dice: 0.504972, loss_ce: 0.540022
iteration 11 : loss: 0.522497, loss_dice: 0.504972, loss_ce: 0.540022
iteration 12 : loss: 0.529663, loss_dice: 0.532648, loss_ce: 0.526678
iteration 12 : loss: 0.529663, loss_dice: 0.532648, loss_ce: 0.526678
iteration 12 : loss: 0.529663, loss_dice: 0.532648, loss_ce: 0.526678


 55%|██████████████████▌               | 6/11 [00:05<00:04,  1.01it/s]

iteration 13 : loss: 0.493981, loss_dice: 0.483166, loss_ce: 0.504795
iteration 13 : loss: 0.493981, loss_dice: 0.483166, loss_ce: 0.504795
iteration 13 : loss: 0.493981, loss_dice: 0.483166, loss_ce: 0.504795
iteration 14 : loss: 0.472644, loss_dice: 0.478000, loss_ce: 0.467288
iteration 14 : loss: 0.472644, loss_dice: 0.478000, loss_ce: 0.467288
iteration 14 : loss: 0.472644, loss_dice: 0.478000, loss_ce: 0.467288


 64%|█████████████████████▋            | 7/11 [00:06<00:03,  1.01it/s]

iteration 15 : loss: 0.490905, loss_dice: 0.519590, loss_ce: 0.462219
iteration 15 : loss: 0.490905, loss_dice: 0.519590, loss_ce: 0.462219
iteration 15 : loss: 0.490905, loss_dice: 0.519590, loss_ce: 0.462219
iteration 16 : loss: 0.390100, loss_dice: 0.437379, loss_ce: 0.342821
iteration 16 : loss: 0.390100, loss_dice: 0.437379, loss_ce: 0.342821
iteration 16 : loss: 0.390100, loss_dice: 0.437379, loss_ce: 0.342821


 73%|████████████████████████▋         | 8/11 [00:07<00:02,  1.01it/s]

iteration 17 : loss: 0.442462, loss_dice: 0.484468, loss_ce: 0.400456
iteration 17 : loss: 0.442462, loss_dice: 0.484468, loss_ce: 0.400456
iteration 17 : loss: 0.442462, loss_dice: 0.484468, loss_ce: 0.400456
iteration 18 : loss: 0.471986, loss_dice: 0.502863, loss_ce: 0.441109
iteration 18 : loss: 0.471986, loss_dice: 0.502863, loss_ce: 0.441109
iteration 18 : loss: 0.471986, loss_dice: 0.502863, loss_ce: 0.441109


 82%|███████████████████████████▊      | 9/11 [00:08<00:01,  1.01it/s]

iteration 19 : loss: 0.467744, loss_dice: 0.462164, loss_ce: 0.473325
iteration 19 : loss: 0.467744, loss_dice: 0.462164, loss_ce: 0.473325
iteration 19 : loss: 0.467744, loss_dice: 0.462164, loss_ce: 0.473325
iteration 20 : loss: 0.460289, loss_dice: 0.493625, loss_ce: 0.426954
iteration 20 : loss: 0.460289, loss_dice: 0.493625, loss_ce: 0.426954
iteration 20 : loss: 0.460289, loss_dice: 0.493625, loss_ce: 0.426954


100%|██████████| 20/20 [00:35<00:00,  1.76s/it]


average metric is 0.07305181710129968
save best model to ./model/BCP/LA_BCP_8_labeled/pre_train/iter_20_dice_0.0731.pth
save best model to ./model/BCP/LA_BCP_8_labeled/pre_train/iter_20_dice_0.0731.pth
save best model to ./model/BCP/LA_BCP_8_labeled/pre_train/iter_20_dice_0.0731.pth


 82%|███████████████████████████▊      | 9/11 [00:45<00:10,  5.05s/it]


In [30]:
# -- Self-training
logging.basicConfig(filename=self_snapshot_path+"/log.txt", level=logging.INFO, format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
logging.info(str(args))
self_train(args, pre_snapshot_path, self_snapshot_path)

<__main__.params object at 0x7fa3f3728fb0>
<__main__.params object at 0x7fa3f3728fb0>
<__main__.params object at 0x7fa3f3728fb0>
<__main__.params object at 0x7fa3f3728fb0>
<__main__.params object at 0x7fa3f3728fb0>
Mode = train, total samples: 80
2 iterations per epoch
2 iterations per epoch
2 iterations per epoch
2 iterations per epoch
2 iterations per epoch


  state = torch.load(str(path))
  return torch.Tensor(batch_list).cuda()


iteration 1 : loss: 1.385639, loss_l: 0.686397, loss_u: 0.699241
iteration 1 : loss: 1.385639, loss_l: 0.686397, loss_u: 0.699241
iteration 1 : loss: 1.385639, loss_l: 0.686397, loss_u: 0.699241
iteration 1 : loss: 1.385639, loss_l: 0.686397, loss_u: 0.699241
iteration 1 : loss: 1.385639, loss_l: 0.686397, loss_u: 0.699241
iteration 2 : loss: 1.297386, loss_l: 0.635858, loss_u: 0.661528
iteration 2 : loss: 1.297386, loss_l: 0.635858, loss_u: 0.661528
iteration 2 : loss: 1.297386, loss_l: 0.635858, loss_u: 0.661528
iteration 2 : loss: 1.297386, loss_l: 0.635858, loss_u: 0.661528
iteration 2 : loss: 1.297386, loss_l: 0.635858, loss_u: 0.661528


 17%|█████▊                             | 1/6 [00:04<00:21,  4.30s/it]

iteration 3 : loss: 1.383565, loss_l: 0.639192, loss_u: 0.744373
iteration 3 : loss: 1.383565, loss_l: 0.639192, loss_u: 0.744373
iteration 3 : loss: 1.383565, loss_l: 0.639192, loss_u: 0.744373
iteration 3 : loss: 1.383565, loss_l: 0.639192, loss_u: 0.744373
iteration 3 : loss: 1.383565, loss_l: 0.639192, loss_u: 0.744373
iteration 4 : loss: 1.326058, loss_l: 0.634418, loss_u: 0.691640
iteration 4 : loss: 1.326058, loss_l: 0.634418, loss_u: 0.691640
iteration 4 : loss: 1.326058, loss_l: 0.634418, loss_u: 0.691640
iteration 4 : loss: 1.326058, loss_l: 0.634418, loss_u: 0.691640
iteration 4 : loss: 1.326058, loss_l: 0.634418, loss_u: 0.691640


 33%|███████████▋                       | 2/6 [00:07<00:14,  3.50s/it]

iteration 5 : loss: 1.346827, loss_l: 0.596793, loss_u: 0.750033
iteration 5 : loss: 1.346827, loss_l: 0.596793, loss_u: 0.750033
iteration 5 : loss: 1.346827, loss_l: 0.596793, loss_u: 0.750033
iteration 5 : loss: 1.346827, loss_l: 0.596793, loss_u: 0.750033
iteration 5 : loss: 1.346827, loss_l: 0.596793, loss_u: 0.750033
iteration 6 : loss: 1.242440, loss_l: 0.624348, loss_u: 0.618091
iteration 6 : loss: 1.242440, loss_l: 0.624348, loss_u: 0.618091
iteration 6 : loss: 1.242440, loss_l: 0.624348, loss_u: 0.618091
iteration 6 : loss: 1.242440, loss_l: 0.624348, loss_u: 0.618091
iteration 6 : loss: 1.242440, loss_l: 0.624348, loss_u: 0.618091


 50%|█████████████████▌                 | 3/6 [00:10<00:09,  3.32s/it]

iteration 7 : loss: 1.395901, loss_l: 0.677066, loss_u: 0.718835
iteration 7 : loss: 1.395901, loss_l: 0.677066, loss_u: 0.718835
iteration 7 : loss: 1.395901, loss_l: 0.677066, loss_u: 0.718835
iteration 7 : loss: 1.395901, loss_l: 0.677066, loss_u: 0.718835
iteration 7 : loss: 1.395901, loss_l: 0.677066, loss_u: 0.718835
iteration 8 : loss: 1.320524, loss_l: 0.648389, loss_u: 0.672135
iteration 8 : loss: 1.320524, loss_l: 0.648389, loss_u: 0.672135
iteration 8 : loss: 1.320524, loss_l: 0.648389, loss_u: 0.672135
iteration 8 : loss: 1.320524, loss_l: 0.648389, loss_u: 0.672135
iteration 8 : loss: 1.320524, loss_l: 0.648389, loss_u: 0.672135


 67%|███████████████████████▎           | 4/6 [00:13<00:06,  3.20s/it]

iteration 9 : loss: 1.213508, loss_l: 0.614705, loss_u: 0.598802
iteration 9 : loss: 1.213508, loss_l: 0.614705, loss_u: 0.598802
iteration 9 : loss: 1.213508, loss_l: 0.614705, loss_u: 0.598802
iteration 9 : loss: 1.213508, loss_l: 0.614705, loss_u: 0.598802
iteration 9 : loss: 1.213508, loss_l: 0.614705, loss_u: 0.598802
iteration 10 : loss: 1.365941, loss_l: 0.709458, loss_u: 0.656484
iteration 10 : loss: 1.365941, loss_l: 0.709458, loss_u: 0.656484
iteration 10 : loss: 1.365941, loss_l: 0.709458, loss_u: 0.656484
iteration 10 : loss: 1.365941, loss_l: 0.709458, loss_u: 0.656484
iteration 10 : loss: 1.365941, loss_l: 0.709458, loss_u: 0.656484


 67%|███████████████████████▎           | 4/6 [00:16<00:08,  4.14s/it]
