In [1]:
!pip install medpy



In [1]:
import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt 
import h5py
import os 
from scipy.ndimage import zoom, rotate
import random 
# torch 
import torch 
from torch.utils.data import Dataset 
from torch.utils.data import DataLoader
from torch.utils.data import Sampler
import torch.optim as optim 
import sys

import itertools
import torchvision.transforms as transforms 
import torch.nn as nn 
import torch.nn.functional as F 
from medpy import metric 
from skimage.measure import label 
import logging
from tqdm import tqdm
import torch.backends.cudnn as cudnn
from torch.nn import CrossEntropyLoss
from tensorboardX import SummaryWriter

#### 1. ACDC Dataset

In [3]:
class BaseDataset(Dataset): 
    def __init__(self, root_path, split= 'train', transform= None, num= None): 
        """
        Use to load name of file.list 

        """
        self.root_path = root_path
        self.split = split 
        self.transform = transform
        self.sample_list = []

        if split == 'train': 
            file_name = os.path.join(self.root_path, 'train_slices.list')
            with open(file_name, 'r') as file: 
                self.sample_list = file.readlines() 
            
            self.sample_list = [item.replace('\n', '') for item in self.sample_list]
            # print(self.sample_list)
        elif split == 'val': 
            file_name = os.path.join(self.root_path, 'val.list')
            with open(file_name, 'r') as file: 
                self.sample_list = file.readlines() 
            
            self.sample_list = [item.replace('\n','') for item in self.sample_list]
            # print(self.sample_list)
        
        else: 
            raise ValueError(f'Mode: {self.split} is not supported')
        

        if isinstance(num, int): 
            self.sample_list = self.sample_list[:num]
        
        print(f'Total slices: {len(self.sample_list)}')
        

    def __len__(self): 
        return len(self.sample_list)

    def __getitem__(self, index):
        case = self.sample_list[index]

        # open the file.h5 
        if self.split == 'train': 
            file_path = os.path.join(self.root_path, f'data/slices/{case}.h5')
        elif self.split == 'val': 
            file_path = os.path.join(self.root_path, f'data/{case}.h5')
        
        h5f = h5py.File(file_path, 'r')
        image = h5f['image'][:]
        label = h5f['label'][:]

        sample = {'image': image, 'label': label}
        if self.split == 'train' and self.transform is not None: 
            sample = self.transform(sample)
        
        sample['case'] = case 
        return sample 

In [4]:
def random_rot_flip(image, label): 
    """
    Random rotate and Random flip 
    """
    
    # Random rotate
    k = np.random.randint(0, 4) 
    image = np.rot90(image, k)
    label = np.rot90(label, k)

    # Random flip 
    axis = np.random.randint(0, 2)
    image = np.flip(image, axis).copy() 
    label = np.flip(label, axis).copy() 

    return image, label 

def random_rotate(image, label):
    angle = np.random.randint(-20, 20) 
    image = rotate(image, angle, order= 0, reshape= False)
    label = rotate(label,angle, order=0, reshape= False )
    return image, label

class RandomGenerator: 
    def __init__(self, output_size): 
        self.output_size = output_size
    
    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        if np.random.random() > 0.5: 
            image, label = random_rot_flip(image, label)
        
        if np.random.random() > 0.5: 
            image, label = random_rotate(image, label) 
        
        # Zoom image to -> [256,256]
        x,y = image.shape
        image = zoom(image, (self.output_size[0] / x, self.output_size[1] / y), order= 0)
        label = zoom(label, (self.output_size[0] /x , self.output_size[1] / y), order= 0)

        # Convert to pytorch 
        imageTensor = torch.from_numpy(image.astype(np.float32)).unsqueeze(0) # image.shape = (1, H, W)
        labelTensor = torch.from_numpy(label.astype(np.uint8)) # label.shape = (H, W)
        sample = {'image': imageTensor, 'label': labelTensor}
        
        return sample
    

In [5]:
def iterate_once(indices): 
    """
    Permutate the iterable once 
    (permutate the labeled_idxs once)
    """
    return np.random.permutation(indices) 

def iterate_externally(indices): 
    """
    Create an infinite iterator that repeatedly permutes the indices.
    ( permutate the unlabeled_idxs to make different)
    """
    def infinite_shuffles(): 
        while True: 
            yield np.random.permutation(indices)
            
    return itertools.chain.from_iterable(infinite_shuffles())

def grouper(iterable, n): 
    args = [iter(iterable)] * n 
    return zip(*args)

class TwoStreamBatchSampler(Sampler): 
    def __init__(self, primary_indicies, secondary_indicies, batchsize, secondary_batchsize): 
        self.primary_indicies = primary_indicies
        self.secondary_indicies = secondary_indicies
        self.primary_batchsize = batchsize - secondary_batchsize
        self.secondary_batchsize = secondary_batchsize

        assert len(self.primary_indicies) >= self.primary_batchsize > 0 
        assert len(self.secondary_indicies) >= self.secondary_batchsize > 0 

    def __iter__(self):
        primary_iter = iterate_once(self.primary_indicies)
        secondary_iter = iterate_externally(self.secondary_indicies)

        return (
            primary_batch + secondary_batch
            for (primary_batch, secondary_batch) 
            in zip(grouper(primary_iter, self.primary_batchsize),
                   grouper(secondary_iter, self.secondary_batchsize))
        )

    def __len__(self): 
        return len(self.primary_indicies) // self.primary_batchsize

In [6]:
# Split the data 
def patients_to_slices(dataset, patients_num): 
    ref_dict = {} 
    if "ACDC" in dataset: 
        ref_dict = {'1': 32, '3': 68, '7': 136, '14': 256, '21': 396, '28': 512, '35': 664, '70': 1312}
    else:
        print('Error')
    
    return ref_dict[str(patients_num)]

#### 2. Loss 

In [7]:
class DiceLoss(nn.Module): 
    def __init__(self, n_classes): 
        super(DiceLoss, self).__init__() 
        self.n_classes = n_classes
    
    def _one_hot_encoder(self, input_tensor): # torch.nn.functional.one_hot()
        """
        Apply one-hot encoder for input_tensor 
        Parameters: 
            - input_tensor.shape = (batchsize,1, H, W), the target image
        """
        tensor_list = [] 
        for i in range(self.n_classes): 
            temp_prob = input_tensor == i * torch.ones_like(input_tensor)
            tensor_list.append(temp_prob)
        output_tensor = torch.cat(tensor_list, dim= 1)
        return output_tensor.float() 
    
    def _dice_loss(self, score, target): 
        target = target.float() 
        smooth = 1e-10 
        
        intersection = torch.sum(score * target)
        union = torch.sum(score* score) + torch.sum(target*target)
        dice = ( 2*intersection + smooth) / (union + smooth)
        loss = 1 - dice 
        return loss 
    
    def _dice_mask_loss(self, score, target, mask): 
        target = target.float() 
        mask = mask.float() 
        smooth = 1e-10 

        intersection = torch.sum(score * target * mask)
        union = torch.sum(score * score * mask ) + torch.sum(target * target * mask)
        dice = (2*intersection + smooth) / (union + smooth)
        loss = 1 - dice 
        return loss 

    def forward(self, inputs, target, mask= None, weight= None, softmax= False): 
        if softmax: 
            inputs = torch.softmax(inputs, dim= 1) 
        
        target = self._one_hot_encoder(target)

        # weight 
        if weight is  None: 
            weight = [1] * self.n_classes
        
        assert inputs.size() == target.size(), 'predict and target shape do not match'
        class_wise_dice = [] 
        loss = 0.0 
        if mask is not None: 
            mask = mask.repeat(1, self.n_classes, 1, 1).type(torch.float32)
            for i in range(0, self.n_classes): 
                dice = self._dice_mask_loss(inputs[:, i], target[:, i], mask[:, i])
                class_wise_dice.append( 1.0 - dice.item())
                loss += dice * weight[i]

        else: 
            for i in range(0, self.n_classes): 
                dice = self._dice_loss(inputs[:, i], target[:, i]) 
                class_wise_dice.append(1.0 - dice.item())
                loss += dice * weight[i] 
        
        return loss / self.n_classes

In [8]:
dice_loss = DiceLoss(n_classes= 4)

In [9]:
def mix_loss(output, img_l, patch_l, mask, l_weight=1.0, u_weight=0.5, unlab=False):
    CE = nn.CrossEntropyLoss(reduction='none')
    img_l, patch_l = img_l.type(torch.int64), patch_l.type(torch.int64)
    output_soft = F.softmax(output, dim=1)
    image_weight, patch_weight = l_weight, u_weight
    if unlab:
        image_weight, patch_weight = u_weight, l_weight
    patch_mask = 1 - mask
    loss_dice = dice_loss(output_soft, img_l.unsqueeze(1), mask.unsqueeze(1)) * image_weight
    loss_dice += dice_loss(output_soft, patch_l.unsqueeze(1), patch_mask.unsqueeze(1)) * patch_weight
    loss_ce = image_weight * (CE(output, img_l) * mask).sum() / (mask.sum() + 1e-16) 
    loss_ce += patch_weight * (CE(output, patch_l) * patch_mask).sum() / (patch_mask.sum() + 1e-16)#loss = loss_ce
    return loss_dice, loss_ce

#### 3.Valid 


In [10]:
def calculate_metric_percase(pred, gt):
    pred[pred > 0] = 1
    gt[gt > 0] = 1
    if pred.sum() > 0:
        dice = metric.binary.dc(pred, gt)
        hd95 = metric.binary.hd95(pred, gt)
        return dice, hd95
    else:
        return 0, 0


def test_single_volume(image, label, model, classes, patch_size=[256, 256]):
    image, label = image.squeeze(0).cpu().detach().numpy(), label.squeeze(0).cpu().detach().numpy()
    prediction = np.zeros_like(label)
    for ind in range(image.shape[0]):
        slice = image[ind, :, :]
        x, y = slice.shape[0], slice.shape[1]
        slice = zoom(slice, (patch_size[0] / x, patch_size[1] / y), order=0)
        input = torch.from_numpy(slice).unsqueeze(0).unsqueeze(0).float().cuda()
        model.eval()
        with torch.no_grad():
            output = model(input)
            if len(output)>1:
                output = output[0]
            out = torch.argmax(torch.softmax(output, dim=1), dim=1).squeeze(0)
            out = out.cpu().detach().numpy()
            pred = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0)
            prediction[ind] = pred
    metric_list = []
    for i in range(1, classes):
        metric_list.append(calculate_metric_percase(prediction == i, label == i))
    return metric_list


#### 3. UNet

In [11]:
from __future__ import division, print_function

import numpy as np
import torch
import torch.nn as nn
import pdb
from torch.nn import functional as F
from torch.distributions.uniform import Uniform


class ConvBlock(nn.Module):
    """two convolution layers with batch norm and leaky relu"""
    def __init__(self, in_channels, out_channels, dropout_p):
        super(ConvBlock, self).__init__()
        self.conv_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(),
            nn.Dropout(dropout_p),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU()
        )

    def forward(self, x):
        return self.conv_conv(x)

class DownBlock(nn.Module):
    """Downsampling followed by ConvBlock"""
    def __init__(self, in_channels, out_channels, dropout_p):
        super(DownBlock, self).__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            ConvBlock(in_channels, out_channels, dropout_p)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class UpBlock(nn.Module):
    """Upsampling followed by ConvBlock"""
    def __init__(self, in_channels1, in_channels2, out_channels, dropout_p):
        super(UpBlock, self).__init__()
        self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size=1)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p)

    def forward(self, x1, x2):
        x1 = self.conv1x1(x1)
        x1 = self.up(x1)
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class Encoder(nn.Module):
    def __init__(self, params):
        super(Encoder, self).__init__()
        self.params = params
        self.in_chns = self.params['in_chns']
        self.ft_chns = self.params['feature_chns']
        self.n_class = self.params['class_num']
        self.dropout = self.params['dropout']
        assert (len(self.ft_chns) == 5)
        self.in_conv = ConvBlock(
            self.in_chns, self.ft_chns[0], self.dropout[0])
        self.down1 = DownBlock(
            self.ft_chns[0], self.ft_chns[1], self.dropout[1])
        self.down2 = DownBlock(
            self.ft_chns[1], self.ft_chns[2], self.dropout[2])
        self.down3 = DownBlock(
            self.ft_chns[2], self.ft_chns[3], self.dropout[3])
        self.down4 = DownBlock(
            self.ft_chns[3], self.ft_chns[4], self.dropout[4])

    def forward(self, x):
        x0 = self.in_conv(x)
        x1 = self.down1(x0)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        x4 = self.down4(x3)
        return [x0, x1, x2, x3, x4]

class Decoder(nn.Module):
    def __init__(self, params):
        super(Decoder, self).__init__()
        self.params = params
        self.in_chns = self.params['in_chns']
        self.ft_chns = self.params['feature_chns']
        self.n_class = self.params['class_num']
        assert (len(self.ft_chns) == 5)

        self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p=0.0)
        self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p=0.0)
        self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p=0.0)
        self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p=0.0)

        self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size=3, padding=1)

    def forward(self, feature):
        x0 = feature[0]
        x1 = feature[1]
        x2 = feature[2]
        x3 = feature[3]
        x4 = feature[4]

        x = self.up1(x4, x3)
        x = self.up2(x, x2)
        x = self.up3(x, x1)
        x_last = self.up4(x, x0)
        output = self.out_conv(x_last)
        return output, x_last

class Decoder_tsne(nn.Module):
    def __init__(self, params):
        super(Decoder_tsne, self).__init__()
        self.params = params
        self.in_chns = self.params['in_chns']
        self.ft_chns = self.params['feature_chns']
        self.n_class = self.params['class_num']
        assert (len(self.ft_chns) == 5)

        self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p=0.0)
        self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p=0.0)
        self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p=0.0)
        self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p=0.0)

        self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size=3, padding=1)

    def forward(self, feature):
        x0 = feature[0]
        x1 = feature[1]
        x2 = feature[2]
        x3 = feature[3]
        x4 = feature[4]

        x5 = self.up1(x4, x3) # 1, 128, 32, 32
        x6 = self.up2(x5, x2) # 1, 64, 64, 64
        x7 = self.up3(x6, x1) # 1, 32, 128, 128
        x_last = self.up4(x7, x0) # 1, 16, 256, 256
        output = self.out_conv(x_last)
        return output, x_last
    
class UNet(nn.Module):
    def __init__(self, in_chns, class_num):
        super(UNet, self).__init__()

        params = {'in_chns': in_chns,
                  'feature_chns': [16, 32, 64, 128, 256],
                  'dropout': [0.05, 0.1, 0.2, 0.3, 0.5],
                  'class_num': class_num,
                  'acti_func': 'relu'}

        self.encoder = Encoder(params)
        self.decoder = Decoder(params)
        dim_in = 16
        feat_dim = 32
        self.projection_head = nn.Sequential(
            nn.Linear(dim_in, feat_dim),
            nn.BatchNorm1d(feat_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feat_dim, feat_dim)
        )
        self.prediction_head = nn.Sequential(
            nn.Linear(feat_dim, feat_dim),
            nn.BatchNorm1d(feat_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feat_dim, feat_dim)
        )
        for class_c in range(4):
            selector = nn.Sequential(
                nn.Linear(feat_dim, feat_dim),
                nn.BatchNorm1d(feat_dim),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Linear(feat_dim, 1)
            )
            self.__setattr__('contrastive_class_selector_' + str(class_c), selector)

        for class_c in range(4):
            selector = nn.Sequential(
                nn.Linear(feat_dim, feat_dim),
                nn.BatchNorm1d(feat_dim),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Linear(feat_dim, 1)
            )
            self.__setattr__('contrastive_class_selector_memory' + str(class_c), selector)
            
    def forward_projection_head(self, features):
        return self.projection_head(features)

    def forward_prediction_head(self, features):
        return self.prediction_head(features)

    def forward(self, x):
        feature = self.encoder(x)
        output, features = self.decoder(feature)
        return output, features

class UNet_2d(nn.Module):
    def __init__(self, in_chns, class_num):
        super(UNet_2d, self).__init__()

        params = {'in_chns': in_chns,
                  'feature_chns': [16, 32, 64, 128, 256],
                  'dropout': [0.05, 0.1, 0.2, 0.3, 0.5],
                  'class_num': class_num,
                  'acti_func': 'relu'}

        self.encoder = Encoder(params)
        self.decoder = Decoder(params)
        dim_in = 16
        feat_dim = 32
        self.projection_head = nn.Sequential(
            nn.Linear(dim_in, feat_dim),
            nn.BatchNorm1d(feat_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feat_dim, feat_dim)
        )
        self.prediction_head = nn.Sequential(
            nn.Linear(feat_dim, feat_dim),
            nn.BatchNorm1d(feat_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feat_dim, feat_dim)
        )
        for class_c in range(4):
            selector = nn.Sequential(
                nn.Linear(feat_dim, feat_dim),
                nn.BatchNorm1d(feat_dim),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Linear(feat_dim, 1)
            )
            self.__setattr__('contrastive_class_selector_' + str(class_c), selector)

        for class_c in range(4):
            selector = nn.Sequential(
                nn.Linear(feat_dim, feat_dim),
                nn.BatchNorm1d(feat_dim),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Linear(feat_dim, 1)
            )
            self.__setattr__('contrastive_class_selector_memory' + str(class_c), selector)
            
    def forward_projection_head(self, features):
        return self.projection_head(features)
        #return self.decoder(features)

    def forward_prediction_head(self, features):
        return self.prediction_head(features)

    def forward(self, x):
        feature = self.encoder(x)
        output, features = self.decoder(feature)
        return output

In [12]:
import torch.nn as nn

def net_factory(net_type="unet", in_chns=1, class_num=2, mode = "train", tsne=0):
    if net_type == "unet" and mode == "train":
        net = UNet(in_chns=in_chns, class_num=class_num).cuda()
    return net

def BCP_net(in_chns=1, class_num=4, ema=False):
    net = UNet_2d(in_chns=in_chns, class_num=class_num).cuda()
    if ema:
        for param in net.parameters():
            param.detach_()
    return net


#### 4.Training 

In [13]:
def get_ACDC_2DLargestCC(segmentation):
    batch_list = []
    N = segmentation.shape[0]
    for i in range(0, N):
        class_list = []
        for c in range(1, 4):
            temp_seg = segmentation[i] #== c *  torch.ones_like(segmentation[i])
            temp_prob = torch.zeros_like(temp_seg)
            temp_prob[temp_seg == c] = 1
            temp_prob = temp_prob.detach().cpu().numpy()
            labels = label(temp_prob)          
            if labels.max() != 0:
                largestCC = labels == np.argmax(np.bincount(labels.flat)[1:])+1
                class_list.append(largestCC * c)
            else:
                class_list.append(temp_prob)
        
        n_batch = class_list[0] + class_list[1] + class_list[2]
        batch_list.append(n_batch)

    return torch.Tensor(batch_list).cuda()
    
def get_ACDC_masks(output, nms=0):
    probs = F.softmax(output, dim=1)
    _, probs = torch.max(probs, dim=1)
    if nms == 1:
        probs = get_ACDC_2DLargestCC(probs)      
    return probs

In [14]:
def save_net_opt(net, optimizer, path):
    state = {
        'net': net.state_dict(),
        'optim': optimizer.state_dict()
    }
    torch.save(state, str(path))

def load_net(net, path):
    state = torch.load(str(path))
    net.load_state_dict(state['net']) 

def load_net_opt(net, optimizer, path): 
    state = torch.load(str(path))
    net.load_state_dict(state['net'])
    optimizer.load_state_dict(state['optim'])

In [15]:
def generate_mask(img):
    batch_size, channel, img_x, img_y = img.shape[0], img.shape[1], img.shape[2], img.shape[3]
    loss_mask = torch.ones(batch_size, img_x, img_y).cuda()
    mask = torch.ones(img_x, img_y).cuda()
    patch_x, patch_y = int(img_x*2/3), int(img_y*2/3)
    w = np.random.randint(0, img_x - patch_x)
    h = np.random.randint(0, img_y - patch_y)
    mask[w:w+patch_x, h:h+patch_y] = 0
    loss_mask[:, w:w+patch_x, h:h+patch_y] = 0
    return mask.long(), loss_mask.long()

In [16]:
def sigmoid_rampup(current, rampup_length):
    if rampup_length == 0: 
        return 1.0 
    else:
        current = np.clip(current, 0, rampup_length)
        phase = 1 - (current / rampup_length)
        return float(np.exp(-5 * phase * phase))

In [17]:
# Mean-Teacher compomnent 
def get_current_consistency_weight(epoch, args): 
    return 5 * args.consistency + sigmoid_rampup(epoch, args.consistency_rampup)

def update_model_ema(model, ema_model, alpha): 
    model_state = model.state_dict() 
    model_ema_state = ema_model.state_dict()


    new_dict = {} 

    for key in model_state:
        new_dict[key] = alpha * model_ema_state[key] + (1 - alpha) * model_state[key]

    ema_model.load_state_dict(new_dict)

In [18]:
# Configuration
def pre_train(args, snapshot_path):
    base_lr = args.base_lr
    num_classes = args.num_classes
    max_iterations = args.pretrain_iterations
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 
    pre_trained_model = os.path.join(snapshot_path, '{}_best_model.pth'.format(args.model))
    labeled_sub_bs, unlabeled_sub_bs = int(args.labeled_bs / 2), int((args.batch_size - args.labeled_bs)/2)

    def worker_init_fn(worker_id):
        random.seed(args.seed + worker_id)

    # load dataset 
    db_train = BaseDataset(root_path = args.root_dir, split= 'train', 
                        num= None, 
                        transform= transforms.Compose([RandomGenerator(args.patch_size)]))
    
    db_val = BaseDataset(root_path= args.root_dir, split= 'val')

    total_slices = len(db_train)
    labeled_slices = patients_to_slices(args.root_dir, args.label_num)
    print(f'Total slice is {total_slices}, Labeled slice is {labeled_slices}')

    # Create batch_sampler 
    labeled_idxs = list(range(0, labeled_slices))
    unlabeled_idxs = list(range(labeled_slices, total_slices))
    batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, args.batch_size, args.batch_size - args.labeled_bs)

    # Create dataloader 
    trainloader = DataLoader(db_train, batch_sampler= batch_sampler, num_workers= 4, pin_memory= True, worker_init_fn= worker_init_fn)
    valloader = DataLoader(db_val, batch_size= 1, shuffle= False, num_workers=1)

    # Define model 
    model = BCP_net(in_chns=1, class_num= num_classes)
    optimizer = optim.SGD(model.parameters(), lr= base_lr, momentum= 0.9, weight_decay= 0.0001)

    writer = SummaryWriter(snapshot_path + '/log')
    logging.info('Start pre-training')
    logging.info(f'{len(trainloader)} iterations per epoch')

    # training process
    model.train() 
    iter_num = 0 
    max_epoch = max_iterations // len(trainloader) + 1 
    best_performance = 0.0 
    best_hd = 100.0
    iterator = tqdm(range(max_epoch), ncols= 70)
    
    for _ in iterator: 
        for _, sampled_batch in enumerate(trainloader): 
            volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda() 

            img_a, img_b = volume_batch[: labeled_sub_bs], volume_batch[labeled_sub_bs : args.labeled_bs]
            lab_a, lab_b = label_batch[: labeled_sub_bs], label_batch[labeled_sub_bs : args.labeled_bs]
            img_mask, loss_mask = generate_mask(img_a)
            gt_mixl = lab_a * img_mask + lab_b * ( 1- img_mask)

            #-- original 
            net_input = img_a * img_mask + img_b * ( 1 - img_mask)  
            out_mixl = model(net_input)
            loss_dice, loss_ce = mix_loss(out_mixl, lab_a, lab_b, loss_mask,u_weight= 1.0, unlab= True )
            loss = (loss_dice + loss_ce )/2 

            optimizer.zero_grad() 
            loss.backward() 
            optimizer.step() 

            iter_num += 1 

            writer.add_scalar('info/total_loss', loss, iter_num)
            writer.add_scalar('info/mix_dice', loss_dice,iter_num )
            writer.add_scalar('info/mix_ce', loss_ce, iter_num)

            logging.info('iteration %d: loss %f, mix_dice: %f, mix_ce: %f'%(iter_num, loss, loss_dice, loss_ce))
            if iter_num % 20 == 0: 
                image = net_input[1, 0:1, :, :]
                writer.add_image('pre_train/Mixed_Image', image, iter_num)
                outputs = torch.argmax(torch.softmax(out_mixl, dim=1), dim=1, keepdim= True)
                writer.add_image('pre_train/Mixed_Prediction', outputs[1, ...]*50, iter_num)
                labs = gt_mixl[1, ...].unsqueeze(0) * 50 
                writer.add_image('pre_train/Mixed_GroundTruth', labs, iter_num)
            
            # Evaluate after 200 epoch ! 
            if iter_num > 0 and iter_num % 200 == 0: 
                model.eval() 
                metric_list = 0.0 
                for _, sampled_batch in enumerate(valloader):
                    metric_i = test_single_volume(sampled_batch['image'], sampled_batch['label'], model, classes= num_classes)
                    metric_list += np.array(metric_i)
                
                metric_list = metric_list / len(db_val)
                
    
                for class_i in range(num_classes - 1 ): 
                    writer.add_scalar('info/val_{}_dice'.format(class_i + 1), metric_list[class_i, 0], iter_num)
                    writer.add_scalar('infor/val_{}_hd'.format(class_i + 1), metric_list[class_i, 1], iter_num)
                
                performance = np.mean(metric_list, axis=0)[0]
                writer.add_scalar('info/val_mean_dice', performance, iter_num)
            
                if performance > best_performance: 
                    best_performance = performance
                    save_model_path = os.path.join(snapshot_path, 'iter_{}_dice_{}.pth'.format(iter_num, round(best_performance,4)))
                    save_best_path = os.path.join(snapshot_path, '{}_best_model.pth'.format(args.model))
                    save_net_opt(model, optimizer, save_model_path)
                    save_net_opt(model, optimizer, save_best_path)
                
                logging.info('iteration %d : mean dice : %f'%(iter_num, performance))
                model.train() 
            
            if iter_num >= max_iterations: 
                break 
        
        if iter_num >= max_iterations: 
            iterator.close() 
            break 
    
    writer.close()

In [19]:
def self_train(args ,pre_snapshot_path, snapshot_path):
    base_lr = args.base_lr
    num_classes = args.num_classes
    max_iterations = args.selftrain_iterations
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    pre_trained_model = os.path.join(pre_snapshot_path,'{}_best_model.pth'.format(args.model))
    labeled_sub_bs, unlabeled_sub_bs = int(args.labeled_bs/2), int((args.batch_size-args.labeled_bs) / 2)
     
    model = BCP_net(in_chns=1, class_num=num_classes)
    ema_model = BCP_net(in_chns=1, class_num=num_classes, ema=True)

    def worker_init_fn(worker_id):
        random.seed(args.seed + worker_id)

    db_train = BaseDataset(root_path=args.root_dir,
                            split="train",
                            num=None,
                            transform=transforms.Compose([RandomGenerator(args.patch_size)]))
    db_val = BaseDataset(root_path=args.root_dir, split="val")
    total_slices = len(db_train)
    labeled_slice = patients_to_slices(args.root_dir,args.label_num)
    print("Total slices is: {}, labeled slices is:{}".format(total_slices, labeled_slice))
    labeled_idxs = list(range(0, labeled_slice))
    unlabeled_idxs = list(range(labeled_slice, total_slices))
    batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, args.batch_size, args.batch_size-args.labeled_bs)

    trainloader = DataLoader(db_train, batch_sampler=batch_sampler, num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn)

    valloader = DataLoader(db_val, batch_size=1, shuffle=False, num_workers=1)

    optimizer = optim.SGD(model.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001)
    load_net(ema_model, pre_trained_model)
    load_net_opt(model, optimizer, pre_trained_model) #
    logging.info("Loaded from {}".format(pre_trained_model))

    writer = SummaryWriter(snapshot_path + '/log')
    logging.info("Start self_training")
    logging.info("{} iterations per epoch".format(len(trainloader)))

    model.train()
    ema_model.train()

    ce_loss = CrossEntropyLoss()

    iter_num = 0
    max_epoch = max_iterations // len(trainloader) + 1
    best_performance = 0.0
    best_hd = 100
    iterator = tqdm(range(max_epoch), ncols=70)
    for _ in iterator:
        for _, sampled_batch in enumerate(trainloader):
            
            volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()

            img_a, img_b = volume_batch[:labeled_sub_bs], volume_batch[labeled_sub_bs:args.labeled_bs]
            uimg_a, uimg_b = volume_batch[args.labeled_bs:args.labeled_bs + unlabeled_sub_bs], volume_batch[args.labeled_bs + unlabeled_sub_bs:]
            ulab_a, ulab_b = label_batch[args.labeled_bs:args.labeled_bs + unlabeled_sub_bs], label_batch[args.labeled_bs + unlabeled_sub_bs:]
            lab_a, lab_b = label_batch[:labeled_sub_bs], label_batch[labeled_sub_bs:args.labeled_bs]
            with torch.no_grad():
                pre_a = ema_model(uimg_a) # pseudo label a 
                pre_b = ema_model(uimg_b) # pseudo label b 
                plab_a = get_ACDC_masks(pre_a, nms=1)
                plab_b = get_ACDC_masks(pre_b, nms=1)
                img_mask, loss_mask = generate_mask(img_a)
                unl_label = ulab_a * img_mask + lab_a * (1 - img_mask)
                l_label = lab_b * img_mask + ulab_b * (1 - img_mask)
            consistency_weight = get_current_consistency_weight(iter_num//150,args)

            net_input_unl = uimg_a * img_mask + img_a * (1 - img_mask)
            net_input_l = img_b * img_mask + uimg_b * (1 - img_mask)
            out_unl = model(net_input_unl)
            out_l = model(net_input_l)
            unl_dice, unl_ce = mix_loss(out_unl, plab_a, lab_a, loss_mask, u_weight=args.u_weight, unlab=True)
            l_dice, l_ce = mix_loss(out_l, lab_b, plab_b, loss_mask, u_weight=args.u_weight)


            loss_ce = unl_ce + l_ce 
            loss_dice = unl_dice + l_dice

            loss = (loss_dice + loss_ce) / 2            

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            iter_num += 1
            update_model_ema(model, ema_model, 0.99)

            writer.add_scalar('info/total_loss', loss, iter_num)
            writer.add_scalar('info/mix_dice', loss_dice, iter_num)
            writer.add_scalar('info/mix_ce', loss_ce, iter_num)
            writer.add_scalar('info/consistency_weight', consistency_weight, iter_num)     

            logging.info('iteration %d: loss: %f, mix_dice: %f, mix_ce: %f'%(iter_num, loss, loss_dice, loss_ce))
                
            if iter_num % 20 == 0:
                image = net_input_unl[1, 0:1, :, :]
                writer.add_image('train/Un_Image', image, iter_num)
                outputs = torch.argmax(torch.softmax(out_unl, dim=1), dim=1, keepdim=True)
                writer.add_image('train/Un_Prediction', outputs[1, ...] * 50, iter_num)
                labs = unl_label[1, ...].unsqueeze(0) * 50
                writer.add_image('train/Un_GroundTruth', labs, iter_num)

                image_l = net_input_l[1, 0:1, :, :]
                writer.add_image('train/L_Image', image_l, iter_num)
                outputs_l = torch.argmax(torch.softmax(out_l, dim=1), dim=1, keepdim=True)
                writer.add_image('train/L_Prediction', outputs_l[1, ...] * 50, iter_num)
                labs_l = l_label[1, ...].unsqueeze(0) * 50
                writer.add_image('train/L_GroundTruth', labs_l, iter_num)

            if iter_num > 0 and iter_num % 200 == 0:
                model.eval()
                metric_list = 0.0
                for _, sampled_batch in enumerate(valloader):
                    metric_i = test_single_volume(sampled_batch["image"], sampled_batch["label"], model, classes= num_classes)
                    metric_list += np.array(metric_i)
                metric_list = metric_list / len(db_val)
                print(f'Metric list: {metric_list}') 
                for class_i in range(num_classes-1):
                    writer.add_scalar('info/val_{}_dice'.format(class_i+1), metric_list[class_i, 0], iter_num)
                    writer.add_scalar('info/val_{}_hd95'.format(class_i+1), metric_list[class_i, 1], iter_num)

                performance = np.mean(metric_list, axis=0)[0]
                writer.add_scalar('info/val_mean_dice', performance, iter_num)

                if performance > best_performance:
                    best_performance = performance
                    save_mode_path = os.path.join(snapshot_path, 'iter_{}_dice_{}.pth'.format(iter_num, round(best_performance, 4)))
                    save_best_path = os.path.join(snapshot_path,'{}_best_model.pth'.format(args.model))
                    torch.save(model.state_dict(), save_mode_path)
                    torch.save(model.state_dict(), save_best_path)

                logging.info('iteration %d : mean_dice : %f' % (iter_num, performance))
                model.train()

            if iter_num >= max_iterations:
                break
        if iter_num >= max_iterations:
            iterator.close()
            break
    writer.close()


In [20]:
# Params 
class params: 
    def __init__(self): 
        self.root_dir = 'ACDC' 
        self.exp = 'BCP' 
        self.model = 'unet' 
        self.pretrain_iterations = 400 
        
        self.selftrain_iterations = 10 
        self.batch_size = 24
        self.deterministic = 1 # What the fucck here
        self.base_lr = 0.01 
        self.patch_size = [256,256] 
        self.seed = 42 
        self.num_classes = 4 

        # label and unlabel 
        self.labeled_bs = 12
        self.label_num = 7 
        self.u_weight = 0.5 

        # Cost 
        self.gpu = '0' 
        self.consistency = 0.1
        self.consistency_rampup = 200.0 
        self.magnitude = '6.0' 
        self.s_param = 6 


args = params()
# Trainig process 
if args.deterministic: 
    cudnn.benchmark = False 
    cudnn.deterministic = True 
    random.seed(args.seed) 
    np.random.seed(args.seed) 
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

# Path 
pre_snapshot_path = "./model/BCP/ACDC_{}_{}_labeled/pretrain".format(args.exp, args.label_num)
self_snapshot_path = "./model/BCP/ACDC_{}_{}_labeled/selftrain".format(args.exp, args.label_num)

print(f'Pretrain log path: {pre_snapshot_path + "/log.txt"}')
print(f'Self-train log path: {self_snapshot_path + "/log.txt"}')

for snapshot_path in [pre_snapshot_path, self_snapshot_path]: 
    if not os.path.exists(snapshot_path): 
        os.makedirs(snapshot_path, exist_ok= True)
#Pre_train
logging.basicConfig(filename=pre_snapshot_path+"/log.txt", level=logging.INFO, format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
pre_train(args, pre_snapshot_path)

#Self_train
logging.basicConfig(filename=self_snapshot_path+"/log.txt", level=logging.INFO, format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
self_train(args, pre_snapshot_path, self_snapshot_path)

Pretrain log path: ./model/BCP/ACDC_BCP_7_labeled/pretrain/log.txt
Self-train log path: ./model/BCP/ACDC_BCP_7_labeled/selftrain/log.txt
Total slices: 1312
Total slices: 20
Total slice is 1312, Labeled slice is 136
Start pre-training
11 iterations per epoch


  0%|                                          | 0/37 [00:00<?, ?it/s]

iteration 1: loss 2.311362, mix_dice: 1.668222, mix_ce: 2.954502
iteration 2: loss 2.158126, mix_dice: 1.645628, mix_ce: 2.670624
iteration 3: loss 1.882130, mix_dice: 1.587863, mix_ce: 2.176397
iteration 4: loss 1.582077, mix_dice: 1.514468, mix_ce: 1.649685
iteration 5: loss 1.322685, mix_dice: 1.451923, mix_ce: 1.193448
iteration 6: loss 1.119387, mix_dice: 1.405992, mix_ce: 0.832782
iteration 7: loss 1.010115, mix_dice: 1.379674, mix_ce: 0.640556
iteration 8: loss 0.916543, mix_dice: 1.388362, mix_ce: 0.444724
iteration 9: loss 0.887238, mix_dice: 1.378385, mix_ce: 0.396090
iteration 10: loss 0.892786, mix_dice: 1.425456, mix_ce: 0.360116
iteration 11: loss 0.947268, mix_dice: 1.382958, mix_ce: 0.511579


  3%|▉                                 | 1/37 [00:01<00:45,  1.27s/it]

iteration 12: loss 0.917974, mix_dice: 1.382751, mix_ce: 0.453198
iteration 13: loss 0.876567, mix_dice: 1.341477, mix_ce: 0.411657
iteration 14: loss 0.943996, mix_dice: 1.410683, mix_ce: 0.477310
iteration 15: loss 0.898264, mix_dice: 1.396505, mix_ce: 0.400024
iteration 16: loss 1.000480, mix_dice: 1.412440, mix_ce: 0.588519
iteration 17: loss 0.886703, mix_dice: 1.408946, mix_ce: 0.364461
iteration 18: loss 0.902835, mix_dice: 1.341402, mix_ce: 0.464269
iteration 19: loss 0.941350, mix_dice: 1.414738, mix_ce: 0.467961
iteration 20: loss 0.869380, mix_dice: 1.434147, mix_ce: 0.304614
iteration 21: loss 0.949446, mix_dice: 1.426566, mix_ce: 0.472326
iteration 22: loss 0.856860, mix_dice: 1.388643, mix_ce: 0.325076


  5%|█▊                                | 2/37 [00:02<00:34,  1.02it/s]

iteration 23: loss 0.890224, mix_dice: 1.380850, mix_ce: 0.399598
iteration 24: loss 0.879951, mix_dice: 1.355667, mix_ce: 0.404234
iteration 25: loss 1.023263, mix_dice: 1.399924, mix_ce: 0.646603
iteration 26: loss 0.859988, mix_dice: 1.296845, mix_ce: 0.423131
iteration 27: loss 0.844160, mix_dice: 1.352127, mix_ce: 0.336194
iteration 28: loss 0.887816, mix_dice: 1.373516, mix_ce: 0.402116
iteration 29: loss 0.869121, mix_dice: 1.357193, mix_ce: 0.381049
iteration 30: loss 0.993220, mix_dice: 1.455630, mix_ce: 0.530810
iteration 31: loss 0.958963, mix_dice: 1.409350, mix_ce: 0.508576
iteration 32: loss 0.901464, mix_dice: 1.389937, mix_ce: 0.412992
iteration 33: loss 0.842478, mix_dice: 1.326654, mix_ce: 0.358303


  8%|██▊                               | 3/37 [00:02<00:30,  1.13it/s]

iteration 34: loss 0.931241, mix_dice: 1.398945, mix_ce: 0.463538
iteration 35: loss 0.841114, mix_dice: 1.296254, mix_ce: 0.385973
iteration 36: loss 0.910951, mix_dice: 1.359834, mix_ce: 0.462068
iteration 37: loss 0.923201, mix_dice: 1.402750, mix_ce: 0.443652
iteration 38: loss 0.883460, mix_dice: 1.374273, mix_ce: 0.392647
iteration 39: loss 0.848368, mix_dice: 1.419623, mix_ce: 0.277114
iteration 40: loss 0.867007, mix_dice: 1.363967, mix_ce: 0.370048
iteration 41: loss 0.891754, mix_dice: 1.431462, mix_ce: 0.352046
iteration 42: loss 0.891011, mix_dice: 1.399373, mix_ce: 0.382648
iteration 43: loss 0.833642, mix_dice: 1.362319, mix_ce: 0.304964
iteration 44: loss 0.814097, mix_dice: 1.340514, mix_ce: 0.287680


 11%|███▋                              | 4/37 [00:03<00:27,  1.19it/s]

iteration 45: loss 0.793858, mix_dice: 1.284611, mix_ce: 0.303105
iteration 46: loss 0.936135, mix_dice: 1.409691, mix_ce: 0.462579
iteration 47: loss 0.862305, mix_dice: 1.323087, mix_ce: 0.401522
iteration 48: loss 0.828952, mix_dice: 1.364495, mix_ce: 0.293410
iteration 49: loss 0.794750, mix_dice: 1.297238, mix_ce: 0.292262
iteration 50: loss 0.881786, mix_dice: 1.319920, mix_ce: 0.443653
iteration 51: loss 0.862831, mix_dice: 1.420365, mix_ce: 0.305298
iteration 52: loss 0.804362, mix_dice: 1.332614, mix_ce: 0.276110
iteration 53: loss 0.836402, mix_dice: 1.321966, mix_ce: 0.350838
iteration 54: loss 0.867220, mix_dice: 1.363717, mix_ce: 0.370724
iteration 55: loss 0.796847, mix_dice: 1.204323, mix_ce: 0.389370


 14%|████▌                             | 5/37 [00:04<00:25,  1.25it/s]

iteration 56: loss 0.795989, mix_dice: 1.315032, mix_ce: 0.276947
iteration 57: loss 0.817172, mix_dice: 1.257257, mix_ce: 0.377088
iteration 58: loss 0.863400, mix_dice: 1.300637, mix_ce: 0.426162
iteration 59: loss 0.850879, mix_dice: 1.351191, mix_ce: 0.350568
iteration 60: loss 0.831964, mix_dice: 1.268912, mix_ce: 0.395017
iteration 61: loss 0.888989, mix_dice: 1.359660, mix_ce: 0.418318
iteration 62: loss 0.846825, mix_dice: 1.303778, mix_ce: 0.389872
iteration 63: loss 0.756408, mix_dice: 1.237477, mix_ce: 0.275340
iteration 64: loss 0.876660, mix_dice: 1.338503, mix_ce: 0.414817
iteration 65: loss 0.758356, mix_dice: 1.207831, mix_ce: 0.308881
iteration 66: loss 0.873477, mix_dice: 1.352881, mix_ce: 0.394073


 16%|█████▌                            | 6/37 [00:05<00:24,  1.25it/s]

iteration 67: loss 0.840473, mix_dice: 1.362358, mix_ce: 0.318589
iteration 68: loss 0.776336, mix_dice: 1.271743, mix_ce: 0.280930
iteration 69: loss 0.839200, mix_dice: 1.340083, mix_ce: 0.338317
iteration 70: loss 0.841912, mix_dice: 1.300637, mix_ce: 0.383187
iteration 71: loss 0.963508, mix_dice: 1.412705, mix_ce: 0.514312
iteration 72: loss 0.776874, mix_dice: 1.250453, mix_ce: 0.303294
iteration 73: loss 0.811730, mix_dice: 1.342596, mix_ce: 0.280863
iteration 74: loss 0.799386, mix_dice: 1.243015, mix_ce: 0.355756
iteration 75: loss 0.784466, mix_dice: 1.268545, mix_ce: 0.300387
iteration 76: loss 0.741411, mix_dice: 1.211984, mix_ce: 0.270837
iteration 77: loss 0.837829, mix_dice: 1.298498, mix_ce: 0.377160


 19%|██████▍                           | 7/37 [00:05<00:23,  1.28it/s]

iteration 78: loss 0.781686, mix_dice: 1.259284, mix_ce: 0.304088
iteration 79: loss 0.827787, mix_dice: 1.325112, mix_ce: 0.330462
iteration 80: loss 0.786571, mix_dice: 1.316383, mix_ce: 0.256760
iteration 81: loss 0.801458, mix_dice: 1.236107, mix_ce: 0.366808
iteration 82: loss 0.803922, mix_dice: 1.265287, mix_ce: 0.342557
iteration 83: loss 0.748440, mix_dice: 1.220730, mix_ce: 0.276150
iteration 84: loss 0.946557, mix_dice: 1.467619, mix_ce: 0.425495
iteration 85: loss 0.811553, mix_dice: 1.318577, mix_ce: 0.304529
iteration 86: loss 0.759938, mix_dice: 1.224133, mix_ce: 0.295744
iteration 87: loss 0.817254, mix_dice: 1.311656, mix_ce: 0.322851
iteration 88: loss 0.769758, mix_dice: 1.278936, mix_ce: 0.260580


 22%|███████▎                          | 8/37 [00:06<00:22,  1.27it/s]

iteration 89: loss 0.761085, mix_dice: 1.205491, mix_ce: 0.316680
iteration 90: loss 0.796581, mix_dice: 1.271668, mix_ce: 0.321494
iteration 91: loss 0.812811, mix_dice: 1.309918, mix_ce: 0.315705
iteration 92: loss 0.733269, mix_dice: 1.150733, mix_ce: 0.315805
iteration 93: loss 0.842396, mix_dice: 1.311216, mix_ce: 0.373576
iteration 94: loss 0.776053, mix_dice: 1.255898, mix_ce: 0.296208
iteration 95: loss 0.759374, mix_dice: 1.215813, mix_ce: 0.302936
iteration 96: loss 0.785344, mix_dice: 1.260904, mix_ce: 0.309783
iteration 97: loss 0.840625, mix_dice: 1.333210, mix_ce: 0.348041
iteration 98: loss 0.685947, mix_dice: 1.164007, mix_ce: 0.207887
iteration 99: loss 0.729871, mix_dice: 1.191180, mix_ce: 0.268563


 24%|████████▎                         | 9/37 [00:07<00:21,  1.29it/s]

iteration 100: loss 0.767697, mix_dice: 1.220072, mix_ce: 0.315322
iteration 101: loss 0.726007, mix_dice: 1.237682, mix_ce: 0.214332
iteration 102: loss 0.770468, mix_dice: 1.280253, mix_ce: 0.260683
iteration 103: loss 0.712289, mix_dice: 1.217557, mix_ce: 0.207021
iteration 104: loss 0.837080, mix_dice: 1.283152, mix_ce: 0.391007
iteration 105: loss 0.719962, mix_dice: 1.231272, mix_ce: 0.208651
iteration 106: loss 0.703168, mix_dice: 1.214529, mix_ce: 0.191808
iteration 107: loss 0.767665, mix_dice: 1.285316, mix_ce: 0.250015
iteration 108: loss 0.768623, mix_dice: 1.239977, mix_ce: 0.297269
iteration 109: loss 0.758055, mix_dice: 1.214642, mix_ce: 0.301468
iteration 110: loss 0.771640, mix_dice: 1.275182, mix_ce: 0.268097


 27%|████████▉                        | 10/37 [00:08<00:21,  1.27it/s]

iteration 111: loss 0.752620, mix_dice: 1.274833, mix_ce: 0.230407
iteration 112: loss 0.745580, mix_dice: 1.208035, mix_ce: 0.283125
iteration 113: loss 0.740859, mix_dice: 1.198487, mix_ce: 0.283231
iteration 114: loss 0.734811, mix_dice: 1.198542, mix_ce: 0.271081
iteration 115: loss 0.747350, mix_dice: 1.202859, mix_ce: 0.291840
iteration 116: loss 0.697970, mix_dice: 1.150135, mix_ce: 0.245804
iteration 117: loss 0.706288, mix_dice: 1.181973, mix_ce: 0.230603
iteration 118: loss 0.669068, mix_dice: 1.123816, mix_ce: 0.214319
iteration 119: loss 0.744267, mix_dice: 1.183625, mix_ce: 0.304908
iteration 120: loss 0.754252, mix_dice: 1.160566, mix_ce: 0.347938
iteration 121: loss 0.833036, mix_dice: 1.317243, mix_ce: 0.348828


 30%|█████████▊                       | 11/37 [00:09<00:20,  1.27it/s]

iteration 122: loss 0.762515, mix_dice: 1.206738, mix_ce: 0.318291
iteration 123: loss 0.776054, mix_dice: 1.239038, mix_ce: 0.313069
iteration 124: loss 0.760044, mix_dice: 1.262785, mix_ce: 0.257304
iteration 125: loss 0.733408, mix_dice: 1.247697, mix_ce: 0.219120
iteration 126: loss 0.778673, mix_dice: 1.275612, mix_ce: 0.281735
iteration 127: loss 0.688041, mix_dice: 1.165775, mix_ce: 0.210307
iteration 128: loss 0.776644, mix_dice: 1.267421, mix_ce: 0.285866
iteration 129: loss 0.705531, mix_dice: 1.182367, mix_ce: 0.228695
iteration 130: loss 0.669021, mix_dice: 1.133118, mix_ce: 0.204923
iteration 131: loss 0.762300, mix_dice: 1.183152, mix_ce: 0.341449
iteration 132: loss 0.834363, mix_dice: 1.344498, mix_ce: 0.324229


 32%|██████████▋                      | 12/37 [00:09<00:19,  1.28it/s]

iteration 133: loss 0.801739, mix_dice: 1.266932, mix_ce: 0.336547
iteration 134: loss 0.731070, mix_dice: 1.232173, mix_ce: 0.229966
iteration 135: loss 0.727155, mix_dice: 1.213759, mix_ce: 0.240552
iteration 136: loss 0.840599, mix_dice: 1.313314, mix_ce: 0.367884
iteration 137: loss 0.717452, mix_dice: 1.210027, mix_ce: 0.224878
iteration 138: loss 0.751821, mix_dice: 1.257536, mix_ce: 0.246106
iteration 139: loss 0.747064, mix_dice: 1.237794, mix_ce: 0.256333
iteration 140: loss 0.764252, mix_dice: 1.223507, mix_ce: 0.304997
iteration 141: loss 0.680447, mix_dice: 1.149424, mix_ce: 0.211471
iteration 142: loss 0.679086, mix_dice: 1.155116, mix_ce: 0.203056
iteration 143: loss 0.916033, mix_dice: 1.410600, mix_ce: 0.421466


 35%|███████████▌                     | 13/37 [00:10<00:18,  1.28it/s]

iteration 144: loss 0.845523, mix_dice: 1.323548, mix_ce: 0.367497
iteration 145: loss 0.735003, mix_dice: 1.218303, mix_ce: 0.251704
iteration 146: loss 0.771424, mix_dice: 1.231159, mix_ce: 0.311690
iteration 147: loss 0.737230, mix_dice: 1.234417, mix_ce: 0.240043
iteration 148: loss 0.718602, mix_dice: 1.194640, mix_ce: 0.242563
iteration 149: loss 0.742863, mix_dice: 1.207501, mix_ce: 0.278225
iteration 150: loss 0.743219, mix_dice: 1.240805, mix_ce: 0.245633
iteration 151: loss 0.627685, mix_dice: 1.045539, mix_ce: 0.209831
iteration 152: loss 0.706826, mix_dice: 1.200148, mix_ce: 0.213504
iteration 153: loss 0.706032, mix_dice: 1.216519, mix_ce: 0.195544
iteration 154: loss 0.743828, mix_dice: 1.208741, mix_ce: 0.278916


 38%|████████████▍                    | 14/37 [00:11<00:17,  1.30it/s]

iteration 155: loss 0.722928, mix_dice: 1.200543, mix_ce: 0.245313
iteration 156: loss 0.689518, mix_dice: 1.165928, mix_ce: 0.213108
iteration 157: loss 0.769561, mix_dice: 1.272362, mix_ce: 0.266759
iteration 158: loss 0.690949, mix_dice: 1.166093, mix_ce: 0.215804
iteration 159: loss 0.785410, mix_dice: 1.240302, mix_ce: 0.330518
iteration 160: loss 0.667881, mix_dice: 1.157517, mix_ce: 0.178246
iteration 161: loss 0.756016, mix_dice: 1.255569, mix_ce: 0.256462
iteration 162: loss 0.845638, mix_dice: 1.352599, mix_ce: 0.338677
iteration 163: loss 0.733664, mix_dice: 1.184262, mix_ce: 0.283067
iteration 164: loss 0.713299, mix_dice: 1.165889, mix_ce: 0.260708
iteration 165: loss 0.803324, mix_dice: 1.278765, mix_ce: 0.327883


 41%|█████████████▍                   | 15/37 [00:12<00:17,  1.26it/s]

iteration 166: loss 0.682841, mix_dice: 1.182991, mix_ce: 0.182691
iteration 167: loss 0.762205, mix_dice: 1.220304, mix_ce: 0.304105
iteration 168: loss 0.795663, mix_dice: 1.276251, mix_ce: 0.315074
iteration 169: loss 0.692642, mix_dice: 1.152909, mix_ce: 0.232374
iteration 170: loss 0.745149, mix_dice: 1.234298, mix_ce: 0.256000
iteration 171: loss 0.651603, mix_dice: 1.133370, mix_ce: 0.169837
iteration 172: loss 0.678033, mix_dice: 1.113343, mix_ce: 0.242724
iteration 173: loss 0.747644, mix_dice: 1.203048, mix_ce: 0.292240
iteration 174: loss 0.708499, mix_dice: 1.205441, mix_ce: 0.211558
iteration 175: loss 0.732032, mix_dice: 1.184566, mix_ce: 0.279497
iteration 176: loss 0.696603, mix_dice: 1.213626, mix_ce: 0.179579


 43%|██████████████▎                  | 16/37 [00:12<00:16,  1.28it/s]

iteration 177: loss 0.715010, mix_dice: 1.221772, mix_ce: 0.208248
iteration 178: loss 0.729056, mix_dice: 1.189210, mix_ce: 0.268901
iteration 179: loss 0.764040, mix_dice: 1.251309, mix_ce: 0.276770
iteration 180: loss 0.622627, mix_dice: 1.090222, mix_ce: 0.155032
iteration 181: loss 0.683726, mix_dice: 1.110645, mix_ce: 0.256807
iteration 182: loss 0.685832, mix_dice: 1.127025, mix_ce: 0.244640
iteration 183: loss 0.726522, mix_dice: 1.227275, mix_ce: 0.225770
iteration 184: loss 0.621318, mix_dice: 1.028063, mix_ce: 0.214573
iteration 185: loss 0.692490, mix_dice: 1.168275, mix_ce: 0.216706
iteration 186: loss 0.741329, mix_dice: 1.215802, mix_ce: 0.266855
iteration 187: loss 0.642837, mix_dice: 1.102865, mix_ce: 0.182809


 46%|███████████████▏                 | 17/37 [00:13<00:15,  1.28it/s]

iteration 188: loss 0.721659, mix_dice: 1.212684, mix_ce: 0.230634
iteration 189: loss 0.722847, mix_dice: 1.222517, mix_ce: 0.223176
iteration 190: loss 0.626972, mix_dice: 1.073727, mix_ce: 0.180216
iteration 191: loss 0.697737, mix_dice: 1.192454, mix_ce: 0.203020
iteration 192: loss 0.774064, mix_dice: 1.284912, mix_ce: 0.263216
iteration 193: loss 0.737606, mix_dice: 1.212446, mix_ce: 0.262766
iteration 194: loss 0.707990, mix_dice: 1.232242, mix_ce: 0.183738
iteration 195: loss 0.690944, mix_dice: 1.182533, mix_ce: 0.199356
iteration 196: loss 0.661576, mix_dice: 1.132063, mix_ce: 0.191089
iteration 197: loss 0.752214, mix_dice: 1.174462, mix_ce: 0.329965
iteration 198: loss 0.684943, mix_dice: 1.174982, mix_ce: 0.194904


 49%|████████████████                 | 18/37 [00:14<00:14,  1.30it/s]

iteration 199: loss 0.748723, mix_dice: 1.218002, mix_ce: 0.279443
iteration 200: loss 0.659323, mix_dice: 1.119105, mix_ce: 0.199541
iteration 200 : mean dice : 0.334669
iteration 201: loss 0.673093, mix_dice: 1.118562, mix_ce: 0.227625
iteration 202: loss 0.686073, mix_dice: 1.149149, mix_ce: 0.222996
iteration 203: loss 0.637258, mix_dice: 1.072922, mix_ce: 0.201594
iteration 204: loss 0.626628, mix_dice: 1.066600, mix_ce: 0.186656
iteration 205: loss 0.752598, mix_dice: 1.273502, mix_ce: 0.231693
iteration 206: loss 0.660614, mix_dice: 1.029794, mix_ce: 0.291434
iteration 207: loss 0.682192, mix_dice: 1.151660, mix_ce: 0.212724
iteration 208: loss 0.728150, mix_dice: 1.120447, mix_ce: 0.335853
iteration 209: loss 0.656492, mix_dice: 1.133996, mix_ce: 0.178989


 51%|████████████████▉                | 19/37 [00:22<00:53,  2.98s/it]

iteration 210: loss 0.542158, mix_dice: 0.920984, mix_ce: 0.163332
iteration 211: loss 0.661399, mix_dice: 1.121684, mix_ce: 0.201114
iteration 212: loss 0.617906, mix_dice: 1.035367, mix_ce: 0.200446
iteration 213: loss 0.684319, mix_dice: 1.149481, mix_ce: 0.219157
iteration 214: loss 0.622772, mix_dice: 1.055501, mix_ce: 0.190043
iteration 215: loss 0.562201, mix_dice: 0.828294, mix_ce: 0.296108
iteration 216: loss 0.701813, mix_dice: 1.152023, mix_ce: 0.251603
iteration 217: loss 0.709609, mix_dice: 1.176254, mix_ce: 0.242963
iteration 218: loss 0.689076, mix_dice: 1.170355, mix_ce: 0.207796
iteration 219: loss 0.752823, mix_dice: 1.267221, mix_ce: 0.238424
iteration 220: loss 0.671349, mix_dice: 1.146641, mix_ce: 0.196058


 54%|█████████████████▊               | 20/37 [00:23<00:40,  2.36s/it]

iteration 221: loss 0.573642, mix_dice: 0.957881, mix_ce: 0.189404
iteration 222: loss 0.633645, mix_dice: 1.098855, mix_ce: 0.168434
iteration 223: loss 0.649471, mix_dice: 1.099516, mix_ce: 0.199425
iteration 224: loss 0.670605, mix_dice: 1.045474, mix_ce: 0.295735
iteration 225: loss 0.638438, mix_dice: 1.085541, mix_ce: 0.191334
iteration 226: loss 0.655848, mix_dice: 1.081789, mix_ce: 0.229906
iteration 227: loss 0.727592, mix_dice: 1.191179, mix_ce: 0.264005
iteration 228: loss 0.759116, mix_dice: 1.158554, mix_ce: 0.359679
iteration 229: loss 0.669001, mix_dice: 1.131886, mix_ce: 0.206116
iteration 230: loss 0.683851, mix_dice: 1.121606, mix_ce: 0.246096
iteration 231: loss 0.633414, mix_dice: 1.083724, mix_ce: 0.183105


 57%|██████████████████▋              | 21/37 [00:24<00:30,  1.93s/it]

iteration 232: loss 0.770835, mix_dice: 1.263961, mix_ce: 0.277709
iteration 233: loss 0.711321, mix_dice: 1.159066, mix_ce: 0.263576
iteration 234: loss 0.619768, mix_dice: 1.053815, mix_ce: 0.185721
iteration 235: loss 0.588748, mix_dice: 0.971972, mix_ce: 0.205525
iteration 236: loss 0.639669, mix_dice: 1.090240, mix_ce: 0.189098
iteration 237: loss 0.634480, mix_dice: 1.030855, mix_ce: 0.238105
iteration 238: loss 0.609516, mix_dice: 1.026533, mix_ce: 0.192499
iteration 239: loss 0.624184, mix_dice: 1.040295, mix_ce: 0.208072
iteration 240: loss 0.665462, mix_dice: 1.083720, mix_ce: 0.247203
iteration 241: loss 0.597600, mix_dice: 1.017393, mix_ce: 0.177807
iteration 242: loss 0.618445, mix_dice: 1.039620, mix_ce: 0.197269


 59%|███████████████████▌             | 22/37 [00:25<00:24,  1.61s/it]

iteration 243: loss 0.603922, mix_dice: 1.021448, mix_ce: 0.186396
iteration 244: loss 0.716615, mix_dice: 1.187838, mix_ce: 0.245392
iteration 245: loss 0.626781, mix_dice: 0.919361, mix_ce: 0.334200
iteration 246: loss 0.606676, mix_dice: 1.038894, mix_ce: 0.174458
iteration 247: loss 0.714754, mix_dice: 1.159701, mix_ce: 0.269807
iteration 248: loss 0.599778, mix_dice: 1.033644, mix_ce: 0.165912
iteration 249: loss 0.689043, mix_dice: 1.138900, mix_ce: 0.239186
iteration 250: loss 0.683673, mix_dice: 1.131893, mix_ce: 0.235453
iteration 251: loss 0.537670, mix_dice: 0.877016, mix_ce: 0.198323
iteration 252: loss 0.656570, mix_dice: 1.090225, mix_ce: 0.222915
iteration 253: loss 0.497573, mix_dice: 0.872782, mix_ce: 0.122363


 62%|████████████████████▌            | 23/37 [00:26<00:19,  1.37s/it]

iteration 254: loss 0.651732, mix_dice: 1.105172, mix_ce: 0.198293
iteration 255: loss 0.641149, mix_dice: 1.120320, mix_ce: 0.161977
iteration 256: loss 0.599569, mix_dice: 1.003698, mix_ce: 0.195440
iteration 257: loss 0.620146, mix_dice: 0.987503, mix_ce: 0.252789
iteration 258: loss 0.666496, mix_dice: 1.105942, mix_ce: 0.227050
iteration 259: loss 0.652545, mix_dice: 1.096372, mix_ce: 0.208718
iteration 260: loss 0.572896, mix_dice: 0.923835, mix_ce: 0.221957
iteration 261: loss 0.750520, mix_dice: 1.191664, mix_ce: 0.309376
iteration 262: loss 0.526531, mix_dice: 0.916470, mix_ce: 0.136592
iteration 263: loss 0.669693, mix_dice: 1.157072, mix_ce: 0.182314
iteration 264: loss 0.686721, mix_dice: 1.098877, mix_ce: 0.274566


 65%|█████████████████████▍           | 24/37 [00:26<00:16,  1.23s/it]

iteration 265: loss 0.668958, mix_dice: 1.117826, mix_ce: 0.220091
iteration 266: loss 0.521351, mix_dice: 0.903024, mix_ce: 0.139678
iteration 267: loss 0.551781, mix_dice: 0.971388, mix_ce: 0.132174
iteration 268: loss 0.531972, mix_dice: 0.902944, mix_ce: 0.161000
iteration 269: loss 0.657560, mix_dice: 1.138029, mix_ce: 0.177091
iteration 270: loss 0.626663, mix_dice: 1.087306, mix_ce: 0.166019
iteration 271: loss 0.781779, mix_dice: 1.170869, mix_ce: 0.392689
iteration 272: loss 0.697370, mix_dice: 1.074791, mix_ce: 0.319949
iteration 273: loss 0.605411, mix_dice: 1.021350, mix_ce: 0.189472
iteration 274: loss 0.687816, mix_dice: 1.115100, mix_ce: 0.260532
iteration 275: loss 0.583047, mix_dice: 0.953306, mix_ce: 0.212787


 68%|██████████████████████▎          | 25/37 [00:27<00:13,  1.10s/it]

iteration 276: loss 0.563991, mix_dice: 0.972576, mix_ce: 0.155406
iteration 277: loss 0.730181, mix_dice: 1.164184, mix_ce: 0.296178
iteration 278: loss 0.663059, mix_dice: 1.095394, mix_ce: 0.230723
iteration 279: loss 0.716006, mix_dice: 1.187593, mix_ce: 0.244420
iteration 280: loss 0.656472, mix_dice: 1.082502, mix_ce: 0.230442
iteration 281: loss 0.608676, mix_dice: 1.013019, mix_ce: 0.204333
iteration 282: loss 0.653697, mix_dice: 1.102993, mix_ce: 0.204400
iteration 283: loss 0.687352, mix_dice: 1.149884, mix_ce: 0.224820
iteration 284: loss 0.607151, mix_dice: 1.043067, mix_ce: 0.171234
iteration 285: loss 0.588941, mix_dice: 1.009418, mix_ce: 0.168463
iteration 286: loss 0.614820, mix_dice: 1.037034, mix_ce: 0.192606


 70%|███████████████████████▏         | 26/37 [00:28<00:11,  1.05s/it]

iteration 287: loss 0.563357, mix_dice: 1.000725, mix_ce: 0.125989
iteration 288: loss 0.551422, mix_dice: 0.966031, mix_ce: 0.136813
iteration 289: loss 0.611042, mix_dice: 1.031286, mix_ce: 0.190797
iteration 290: loss 0.639935, mix_dice: 1.063554, mix_ce: 0.216315
iteration 291: loss 0.597602, mix_dice: 0.978959, mix_ce: 0.216246
iteration 292: loss 0.652325, mix_dice: 1.054965, mix_ce: 0.249685
iteration 293: loss 0.657006, mix_dice: 1.040078, mix_ce: 0.273935
iteration 294: loss 0.625901, mix_dice: 1.027646, mix_ce: 0.224156
iteration 295: loss 0.639545, mix_dice: 1.007269, mix_ce: 0.271822
iteration 296: loss 0.591205, mix_dice: 0.980072, mix_ce: 0.202338
iteration 297: loss 0.621485, mix_dice: 1.003024, mix_ce: 0.239947


 73%|████████████████████████         | 27/37 [00:29<00:10,  1.00s/it]

iteration 298: loss 0.716446, mix_dice: 1.211701, mix_ce: 0.221192
iteration 299: loss 0.614447, mix_dice: 0.993896, mix_ce: 0.234999
iteration 300: loss 0.627385, mix_dice: 1.077786, mix_ce: 0.176984
iteration 301: loss 0.481601, mix_dice: 0.817601, mix_ce: 0.145601
iteration 302: loss 0.565760, mix_dice: 0.971146, mix_ce: 0.160373
iteration 303: loss 0.683119, mix_dice: 1.123972, mix_ce: 0.242266
iteration 304: loss 0.581910, mix_dice: 0.947117, mix_ce: 0.216704
iteration 305: loss 0.570369, mix_dice: 0.985679, mix_ce: 0.155060
iteration 306: loss 0.553508, mix_dice: 0.898399, mix_ce: 0.208618
iteration 307: loss 0.550860, mix_dice: 0.948626, mix_ce: 0.153094
iteration 308: loss 0.544864, mix_dice: 0.949007, mix_ce: 0.140722


 76%|████████████████████████▉        | 28/37 [00:30<00:08,  1.04it/s]

iteration 309: loss 0.603291, mix_dice: 1.035309, mix_ce: 0.171273
iteration 310: loss 0.582520, mix_dice: 1.001685, mix_ce: 0.163355
iteration 311: loss 0.619962, mix_dice: 1.047013, mix_ce: 0.192911
iteration 312: loss 0.623114, mix_dice: 1.034469, mix_ce: 0.211760
iteration 313: loss 0.524362, mix_dice: 0.909340, mix_ce: 0.139383
iteration 314: loss 0.590376, mix_dice: 1.029830, mix_ce: 0.150923
iteration 315: loss 0.567922, mix_dice: 0.989749, mix_ce: 0.146095
iteration 316: loss 0.546008, mix_dice: 0.956448, mix_ce: 0.135567
iteration 317: loss 0.591842, mix_dice: 0.985973, mix_ce: 0.197711
iteration 318: loss 0.646123, mix_dice: 1.023139, mix_ce: 0.269107
iteration 319: loss 0.687629, mix_dice: 1.135720, mix_ce: 0.239537


 78%|█████████████████████████▊       | 29/37 [00:31<00:07,  1.09it/s]

iteration 320: loss 0.530954, mix_dice: 0.923492, mix_ce: 0.138416
iteration 321: loss 0.540060, mix_dice: 0.962174, mix_ce: 0.117946
iteration 322: loss 0.544030, mix_dice: 0.937471, mix_ce: 0.150589
iteration 323: loss 0.571004, mix_dice: 0.987131, mix_ce: 0.154877
iteration 324: loss 0.568295, mix_dice: 0.999468, mix_ce: 0.137122
iteration 325: loss 0.509018, mix_dice: 0.929737, mix_ce: 0.088300
iteration 326: loss 0.628607, mix_dice: 1.017078, mix_ce: 0.240137
iteration 327: loss 0.575202, mix_dice: 0.969285, mix_ce: 0.181118
iteration 328: loss 0.588407, mix_dice: 0.994937, mix_ce: 0.181877
iteration 329: loss 0.553669, mix_dice: 0.938991, mix_ce: 0.168348
iteration 330: loss 0.517586, mix_dice: 0.928846, mix_ce: 0.106325


 81%|██████████████████████████▊      | 30/37 [00:32<00:06,  1.11it/s]

iteration 331: loss 0.591406, mix_dice: 1.033285, mix_ce: 0.149527
iteration 332: loss 0.589741, mix_dice: 0.986563, mix_ce: 0.192918
iteration 333: loss 0.569340, mix_dice: 0.966839, mix_ce: 0.171841
iteration 334: loss 0.571570, mix_dice: 1.005221, mix_ce: 0.137918
iteration 335: loss 0.576931, mix_dice: 0.987539, mix_ce: 0.166322
iteration 336: loss 0.522912, mix_dice: 0.925878, mix_ce: 0.119947
iteration 337: loss 0.511962, mix_dice: 0.853952, mix_ce: 0.169972
iteration 338: loss 0.443049, mix_dice: 0.750850, mix_ce: 0.135248
iteration 339: loss 0.597390, mix_dice: 1.005115, mix_ce: 0.189665
iteration 340: loss 0.587915, mix_dice: 0.992818, mix_ce: 0.183012
iteration 341: loss 0.582671, mix_dice: 1.003233, mix_ce: 0.162108


 84%|███████████████████████████▋     | 31/37 [00:33<00:05,  1.12it/s]

iteration 342: loss 0.566731, mix_dice: 0.966999, mix_ce: 0.166464
iteration 343: loss 0.614362, mix_dice: 1.033542, mix_ce: 0.195182
iteration 344: loss 0.638096, mix_dice: 1.051639, mix_ce: 0.224553
iteration 345: loss 0.522688, mix_dice: 0.922517, mix_ce: 0.122860
iteration 346: loss 0.536184, mix_dice: 0.859339, mix_ce: 0.213028
iteration 347: loss 0.604740, mix_dice: 0.967914, mix_ce: 0.241566
iteration 348: loss 0.519223, mix_dice: 0.921892, mix_ce: 0.116553
iteration 349: loss 0.469636, mix_dice: 0.842837, mix_ce: 0.096435
iteration 350: loss 0.509988, mix_dice: 0.905838, mix_ce: 0.114138
iteration 351: loss 0.475610, mix_dice: 0.806626, mix_ce: 0.144594
iteration 352: loss 0.527320, mix_dice: 0.941304, mix_ce: 0.113335


 86%|████████████████████████████▌    | 32/37 [00:33<00:04,  1.16it/s]

iteration 353: loss 0.553108, mix_dice: 0.972263, mix_ce: 0.133954
iteration 354: loss 0.554634, mix_dice: 1.006947, mix_ce: 0.102321
iteration 355: loss 0.496885, mix_dice: 0.888674, mix_ce: 0.105097
iteration 356: loss 0.508343, mix_dice: 0.915776, mix_ce: 0.100910
iteration 357: loss 0.602681, mix_dice: 1.010797, mix_ce: 0.194564
iteration 358: loss 0.645926, mix_dice: 1.059620, mix_ce: 0.232231
iteration 359: loss 0.529425, mix_dice: 0.878869, mix_ce: 0.179980
iteration 360: loss 0.544810, mix_dice: 0.939725, mix_ce: 0.149895
iteration 361: loss 0.540834, mix_dice: 0.929619, mix_ce: 0.152049
iteration 362: loss 0.519360, mix_dice: 0.857591, mix_ce: 0.181130
iteration 363: loss 0.586935, mix_dice: 1.019754, mix_ce: 0.154117


 89%|█████████████████████████████▍   | 33/37 [00:34<00:03,  1.16it/s]

iteration 364: loss 0.672165, mix_dice: 1.125167, mix_ce: 0.219162
iteration 365: loss 0.535151, mix_dice: 0.938391, mix_ce: 0.131911
iteration 366: loss 0.469307, mix_dice: 0.799857, mix_ce: 0.138757
iteration 367: loss 0.591966, mix_dice: 1.001443, mix_ce: 0.182489
iteration 368: loss 0.770559, mix_dice: 1.176610, mix_ce: 0.364507
iteration 369: loss 0.585701, mix_dice: 0.973735, mix_ce: 0.197668
iteration 370: loss 0.528869, mix_dice: 0.930254, mix_ce: 0.127484
iteration 371: loss 0.660108, mix_dice: 1.089040, mix_ce: 0.231177
iteration 372: loss 0.701976, mix_dice: 1.172050, mix_ce: 0.231902
iteration 373: loss 0.585846, mix_dice: 0.982254, mix_ce: 0.189438
iteration 374: loss 0.557597, mix_dice: 0.948007, mix_ce: 0.167186


 92%|██████████████████████████████▎  | 34/37 [00:35<00:02,  1.20it/s]

iteration 375: loss 0.643272, mix_dice: 1.117305, mix_ce: 0.169239
iteration 376: loss 0.540008, mix_dice: 0.919304, mix_ce: 0.160712
iteration 377: loss 0.549075, mix_dice: 0.909540, mix_ce: 0.188609
iteration 378: loss 0.536849, mix_dice: 0.942758, mix_ce: 0.130941
iteration 379: loss 0.658562, mix_dice: 1.131678, mix_ce: 0.185446
iteration 380: loss 0.596741, mix_dice: 1.012467, mix_ce: 0.181015
iteration 381: loss 0.583010, mix_dice: 0.986428, mix_ce: 0.179592
iteration 382: loss 0.564834, mix_dice: 0.971981, mix_ce: 0.157687
iteration 383: loss 0.495803, mix_dice: 0.850260, mix_ce: 0.141347
iteration 384: loss 0.566004, mix_dice: 0.978886, mix_ce: 0.153122
iteration 385: loss 0.593005, mix_dice: 0.966699, mix_ce: 0.219312


 95%|███████████████████████████████▏ | 35/37 [00:36<00:01,  1.20it/s]

iteration 386: loss 0.509992, mix_dice: 0.850936, mix_ce: 0.169047
iteration 387: loss 0.538161, mix_dice: 0.867620, mix_ce: 0.208702
iteration 388: loss 0.615012, mix_dice: 1.038119, mix_ce: 0.191905
iteration 389: loss 0.680681, mix_dice: 1.125991, mix_ce: 0.235370
iteration 390: loss 0.653950, mix_dice: 1.088577, mix_ce: 0.219322
iteration 391: loss 0.551948, mix_dice: 0.969487, mix_ce: 0.134409
iteration 392: loss 0.602007, mix_dice: 1.035932, mix_ce: 0.168082
iteration 393: loss 0.444566, mix_dice: 0.773048, mix_ce: 0.116084
iteration 394: loss 0.439483, mix_dice: 0.643160, mix_ce: 0.235805
iteration 395: loss 0.734568, mix_dice: 1.276594, mix_ce: 0.192542
iteration 396: loss 0.670772, mix_dice: 1.087008, mix_ce: 0.254537


 97%|████████████████████████████████ | 36/37 [00:37<00:00,  1.23it/s]

iteration 397: loss 0.561461, mix_dice: 1.003679, mix_ce: 0.119243
iteration 398: loss 0.609728, mix_dice: 1.020160, mix_ce: 0.199295
iteration 399: loss 0.542429, mix_dice: 0.962678, mix_ce: 0.122179
iteration 400: loss 0.469323, mix_dice: 0.812316, mix_ce: 0.126331
iteration 400 : mean dice : 0.550409


 97%|████████████████████████████████ | 36/37 [00:44<00:01,  1.24s/it]

Total slices: 1312
Total slices: 20
Total slices is: 1312, labeled slices is:136
Loaded from ./model/BCP/ACDC_BCP_7_labeled/pretrain/unet_best_model.pth
Loaded from ./model/BCP/ACDC_BCP_7_labeled/pretrain/unet_best_model.pth
Start self_training
Start self_training
11 iterations per epoch
11 iterations per epoch



  state = torch.load(str(path))
  state = torch.load(str(path))
  0%|                                           | 0/1 [00:00<?, ?it/s]

iteration 1: loss: 0.717569, mix_dice: 1.178169, mix_ce: 0.256968
iteration 1: loss: 0.717569, mix_dice: 1.178169, mix_ce: 0.256968


  return torch.Tensor(batch_list).cuda()


iteration 2: loss: 0.908124, mix_dice: 1.512403, mix_ce: 0.303845
iteration 2: loss: 0.908124, mix_dice: 1.512403, mix_ce: 0.303845
iteration 3: loss: 0.777800, mix_dice: 1.337923, mix_ce: 0.217677
iteration 3: loss: 0.777800, mix_dice: 1.337923, mix_ce: 0.217677
iteration 4: loss: 0.908684, mix_dice: 1.538955, mix_ce: 0.278413
iteration 4: loss: 0.908684, mix_dice: 1.538955, mix_ce: 0.278413
iteration 5: loss: 0.840437, mix_dice: 1.450593, mix_ce: 0.230281
iteration 5: loss: 0.840437, mix_dice: 1.450593, mix_ce: 0.230281
iteration 6: loss: 0.788802, mix_dice: 1.386643, mix_ce: 0.190961
iteration 6: loss: 0.788802, mix_dice: 1.386643, mix_ce: 0.190961
iteration 7: loss: 0.751624, mix_dice: 1.317435, mix_ce: 0.185813
iteration 7: loss: 0.751624, mix_dice: 1.317435, mix_ce: 0.185813
iteration 8: loss: 0.886189, mix_dice: 1.543647, mix_ce: 0.228730
iteration 8: loss: 0.886189, mix_dice: 1.543647, mix_ce: 0.228730
iteration 9: loss: 0.835229, mix_dice: 1.451798, mix_ce: 0.218660
iteration 

  0%|                                           | 0/1 [00:02<?, ?it/s]
