In [83]:
!pip install medpy



In [84]:
import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt 
import h5py
import os 
from scipy.ndimage import zoom, rotate
import random 
# torch 
import torch 
from torch.utils.data import Dataset 
from torch.utils.data import DataLoader
from torch.utils.data import Sampler
import torch.optim as optim 
import sys

import itertools
import torchvision.transforms as transforms 
import torch.nn as nn 
import torch.nn.functional as F 
from medpy import metric 
from skimage.measure import label 
import logging
from tqdm import tqdm
import torch.backends.cudnn as cudnn
from torch.nn import CrossEntropyLoss
from tensorboardX import SummaryWriter

#### 1. ACDC Dataset

In [85]:
class BaseDataset(Dataset): 
    def __init__(self, root_path, split= 'train', transform= None, num= None): 
        """
        Use to load name of file.list 

        """
        self.root_path = root_path
        self.split = split 
        self.transform = transform
        self.sample_list = []

        if split == 'train': 
            file_name = os.path.join(self.root_path, 'train_slices.list')
            with open(file_name, 'r') as file: 
                self.sample_list = file.readlines() 
            
            self.sample_list = [item.replace('\n', '') for item in self.sample_list]
            # print(self.sample_list)
        elif split == 'val': 
            file_name = os.path.join(self.root_path, 'val.list')
            with open(file_name, 'r') as file: 
                self.sample_list = file.readlines() 
            
            self.sample_list = [item.replace('\n','') for item in self.sample_list]
            # print(self.sample_list)
        
        else: 
            raise ValueError(f'Mode: {self.split} is not supported')
        

        if isinstance(num, int): 
            self.sample_list = self.sample_list[:num]
        
        print(f'Total slices: {len(self.sample_list)}')
        

    def __len__(self): 
        return len(self.sample_list)

    def __getitem__(self, index):
        case = self.sample_list[index]

        # open the file.h5 
        if self.split == 'train': 
            file_path = os.path.join(self.root_path, f'data/slices/{case}.h5')
        elif self.split == 'val': 
            file_path = os.path.join(self.root_path, f'data/{case}.h5')
        
        h5f = h5py.File(file_path, 'r')
        image = h5f['image'][:]
        label = h5f['label'][:]

        sample = {'image': image, 'label': label}
        if self.split == 'train' and self.transform is not None: 
            sample = self.transform(sample)
        
        sample['case'] = case 
        return sample 

In [86]:
def random_rot_flip(image, label): 
    """
    Random rotate and Random flip 
    """
    
    # Random rotate
    k = np.random.randint(0, 4) 
    image = np.rot90(image, k)
    label = np.rot90(label, k)

    # Random flip 
    axis = np.random.randint(0, 2)
    image = np.flip(image, axis).copy() 
    label = np.flip(label, axis).copy() 

    return image, label 

def random_rotate(image, label):
    angle = np.random.randint(-20, 20) 
    image = rotate(image, angle, order= 0, reshape= False)
    label = rotate(label,angle, order=0, reshape= False )
    return image, label

class RandomGenerator: 
    def __init__(self, output_size): 
        self.output_size = output_size
    
    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        if np.random.random() > 0.5: 
            image, label = random_rot_flip(image, label)
        
        if np.random.random() > 0.5: 
            image, label = random_rotate(image, label) 
        
        # Zoom image to -> [256,256]
        x,y = image.shape
        image = zoom(image, (self.output_size[0] / x, self.output_size[1] / y), order= 0)
        label = zoom(label, (self.output_size[0] /x , self.output_size[1] / y), order= 0)

        # Convert to pytorch 
        imageTensor = torch.from_numpy(image.astype(np.float32)).unsqueeze(0) # image.shape = (1, H, W)
        labelTensor = torch.from_numpy(label.astype(np.uint8)) # label.shape = (H, W)
        sample = {'image': imageTensor, 'label': labelTensor}
        
        return sample
    

In [87]:
def iterate_once(indices): 
    """
    Permutate the iterable once 
    (permutate the labeled_idxs once)
    """
    return np.random.permutation(indices) 

def iterate_externally(indices): 
    """
    Create an infinite iterator that repeatedly permutes the indices.
    ( permutate the unlabeled_idxs to make different)
    """
    def infinite_shuffles(): 
        while True: 
            yield np.random.permutation(indices)
            
    return itertools.chain.from_iterable(infinite_shuffles())

def grouper(iterable, n): 
    args = [iter(iterable)] * n 
    return zip(*args)

class TwoStreamBatchSampler(Sampler): 
    def __init__(self, primary_indicies, secondary_indicies, batchsize, secondary_batchsize): 
        self.primary_indicies = primary_indicies
        self.secondary_indicies = secondary_indicies
        self.primary_batchsize = batchsize - secondary_batchsize
        self.secondary_batchsize = secondary_batchsize

        assert len(self.primary_indicies) >= self.primary_batchsize > 0 
        assert len(self.secondary_indicies) >= self.secondary_batchsize > 0 

    def __iter__(self):
        primary_iter = iterate_once(self.primary_indicies)
        secondary_iter = iterate_externally(self.secondary_indicies)

        return (
            primary_batch + secondary_batch
            for (primary_batch, secondary_batch) 
            in zip(grouper(primary_iter, self.primary_batchsize),
                   grouper(secondary_iter, self.secondary_batchsize))
        )

    def __len__(self): 
        return len(self.primary_indicies) // self.primary_batchsize

In [88]:
# Split the data 
def patients_to_slices(dataset, patients_num): 
    ref_dict = {} 
    if "ACDC" in dataset: 
        ref_dict = {'1': 32, '3': 68, '7': 136, '14': 256, '21': 396, '28': 512, '35': 664, '70': 1312}
    else:
        print('Error')
    
    return ref_dict[str(patients_num)]

#### 2. Loss 

In [89]:
class DiceLoss(nn.Module): 
    def __init__(self, n_classes): 
        super(DiceLoss, self).__init__() 
        self.n_classes = n_classes
    
    def _one_hot_encoder(self, input_tensor): # torch.nn.functional.one_hot()
        """
        Apply one-hot encoder for input_tensor 
        Parameters: 
            - input_tensor.shape = (batchsize,1, H, W), the target image
        """
        tensor_list = [] 
        for i in range(self.n_classes): 
            temp_prob = input_tensor == i * torch.ones_like(input_tensor)
            tensor_list.append(temp_prob)
        output_tensor = torch.cat(tensor_list, dim= 1)
        return output_tensor.float() 
    
    def _dice_loss(self, score, target): 
        target = target.float() 
        smooth = 1e-10 
        
        intersection = torch.sum(score * target)
        union = torch.sum(score* score) + torch.sum(target*target)
        dice = ( 2*intersection + smooth) / (union + smooth)
        loss = 1 - dice 
        return loss 
    
    def _dice_mask_loss(self, score, target, mask): 
        target = target.float() 
        mask = mask.float() 
        smooth = 1e-10 

        intersection = torch.sum(score * target * mask)
        union = torch.sum(score * score * mask ) + torch.sum(target * target * mask)
        dice = (2*intersection + smooth) / (union + smooth)
        loss = 1 - dice 
        return loss 

    def forward(self, inputs, target, mask= None, weight= None, softmax= False): 
        if softmax: 
            inputs = torch.softmax(inputs, dim= 1) 
        
        target = self._one_hot_encoder(target)

        # weight 
        if weight is  None: 
            weight = [1] * self.n_classes
        
        assert inputs.size() == target.size(), 'predict and target shape do not match'
        class_wise_dice = [] 
        loss = 0.0 
        if mask is not None: 
            mask = mask.repeat(1, self.n_classes, 1, 1).type(torch.float32)
            for i in range(0, self.n_classes): 
                dice = self._dice_mask_loss(inputs[:, i], target[:, i], mask[:, i])
                class_wise_dice.append( 1.0 - dice.item())
                loss += dice * weight[i]

        else: 
            for i in range(0, self.n_classes): 
                dice = self._dice_loss(inputs[:, i], target[:, i]) 
                class_wise_dice.append(1.0 - dice.item())
                loss += dice * weight[i] 
        
        return loss / self.n_classes

In [90]:
dice_loss = DiceLoss(n_classes= 4)

In [91]:
def mix_loss(output, img_l, patch_l, mask, l_weight=1.0, u_weight=0.5, unlab=False):
    CE = nn.CrossEntropyLoss(reduction='none')
    img_l, patch_l = img_l.type(torch.int64), patch_l.type(torch.int64)
    output_soft = F.softmax(output, dim=1)
    image_weight, patch_weight = l_weight, u_weight
    if unlab:
        image_weight, patch_weight = u_weight, l_weight
    patch_mask = 1 - mask
    loss_dice = dice_loss(output_soft, img_l.unsqueeze(1), mask.unsqueeze(1)) * image_weight
    loss_dice += dice_loss(output_soft, patch_l.unsqueeze(1), patch_mask.unsqueeze(1)) * patch_weight
    loss_ce = image_weight * (CE(output, img_l) * mask).sum() / (mask.sum() + 1e-16) 
    loss_ce += patch_weight * (CE(output, patch_l) * patch_mask).sum() / (patch_mask.sum() + 1e-16)#loss = loss_ce
    return loss_dice, loss_ce

#### 3.Valid 


In [92]:
def calculate_metric_percase(pred, gt):
    pred[pred > 0] = 1
    gt[gt > 0] = 1
    if pred.sum() > 0:
        dice = metric.binary.dc(pred, gt)
        hd95 = metric.binary.hd95(pred, gt)
        return dice, hd95
    else:
        return 0, 0


def test_single_volume(image, label, model, classes, patch_size=[256, 256]):
    image, label = image.squeeze(0).cpu().detach().numpy(), label.squeeze(0).cpu().detach().numpy()
    prediction = np.zeros_like(label)
    for ind in range(image.shape[0]):
        slice = image[ind, :, :]
        x, y = slice.shape[0], slice.shape[1]
        slice = zoom(slice, (patch_size[0] / x, patch_size[1] / y), order=0)
        input = torch.from_numpy(slice).unsqueeze(0).unsqueeze(0).float().cuda()
        model.eval()
        with torch.no_grad():
            output = model(input)
            if len(output)>1:
                output = output[0]
            out = torch.argmax(torch.softmax(output, dim=1), dim=1).squeeze(0)
            out = out.cpu().detach().numpy()
            pred = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0)
            prediction[ind] = pred
    metric_list = []
    for i in range(1, classes):
        metric_list.append(calculate_metric_percase(prediction == i, label == i))
    return metric_list


#### 3. UNet

In [93]:
from __future__ import division, print_function

import numpy as np
import torch
import torch.nn as nn
import pdb
from torch.nn import functional as F
from torch.distributions.uniform import Uniform


class ConvBlock(nn.Module):
    """two convolution layers with batch norm and leaky relu"""
    def __init__(self, in_channels, out_channels, dropout_p):
        super(ConvBlock, self).__init__()
        self.conv_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(),
            nn.Dropout(dropout_p),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU()
        )

    def forward(self, x):
        return self.conv_conv(x)

class DownBlock(nn.Module):
    """Downsampling followed by ConvBlock"""
    def __init__(self, in_channels, out_channels, dropout_p):
        super(DownBlock, self).__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            ConvBlock(in_channels, out_channels, dropout_p)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class UpBlock(nn.Module):
    """Upsampling followed by ConvBlock"""
    def __init__(self, in_channels1, in_channels2, out_channels, dropout_p):
        super(UpBlock, self).__init__()
        self.conv1x1 = nn.Conv2d(in_channels1, in_channels2, kernel_size=1)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = ConvBlock(in_channels2 * 2, out_channels, dropout_p)

    def forward(self, x1, x2):
        x1 = self.conv1x1(x1)
        x1 = self.up(x1)
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class Encoder(nn.Module):
    def __init__(self, params):
        super(Encoder, self).__init__()
        self.params = params
        self.in_chns = self.params['in_chns']
        self.ft_chns = self.params['feature_chns']
        self.n_class = self.params['class_num']
        self.dropout = self.params['dropout']
        assert (len(self.ft_chns) == 5)
        self.in_conv = ConvBlock(
            self.in_chns, self.ft_chns[0], self.dropout[0])
        self.down1 = DownBlock(
            self.ft_chns[0], self.ft_chns[1], self.dropout[1])
        self.down2 = DownBlock(
            self.ft_chns[1], self.ft_chns[2], self.dropout[2])
        self.down3 = DownBlock(
            self.ft_chns[2], self.ft_chns[3], self.dropout[3])
        self.down4 = DownBlock(
            self.ft_chns[3], self.ft_chns[4], self.dropout[4])

    def forward(self, x):
        x0 = self.in_conv(x)
        x1 = self.down1(x0)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        x4 = self.down4(x3)
        return [x0, x1, x2, x3, x4]

class Decoder(nn.Module):
    def __init__(self, params):
        super(Decoder, self).__init__()
        self.params = params
        self.in_chns = self.params['in_chns']
        self.ft_chns = self.params['feature_chns']
        self.n_class = self.params['class_num']
        assert (len(self.ft_chns) == 5)

        self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p=0.0)
        self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p=0.0)
        self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p=0.0)
        self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p=0.0)

        self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size=3, padding=1)

    def forward(self, feature):
        x0 = feature[0]
        x1 = feature[1]
        x2 = feature[2]
        x3 = feature[3]
        x4 = feature[4]

        x = self.up1(x4, x3)
        x = self.up2(x, x2)
        x = self.up3(x, x1)
        x_last = self.up4(x, x0)
        output = self.out_conv(x_last)
        return output, x_last

class Decoder_tsne(nn.Module):
    def __init__(self, params):
        super(Decoder_tsne, self).__init__()
        self.params = params
        self.in_chns = self.params['in_chns']
        self.ft_chns = self.params['feature_chns']
        self.n_class = self.params['class_num']
        assert (len(self.ft_chns) == 5)

        self.up1 = UpBlock(self.ft_chns[4], self.ft_chns[3], self.ft_chns[3], dropout_p=0.0)
        self.up2 = UpBlock(self.ft_chns[3], self.ft_chns[2], self.ft_chns[2], dropout_p=0.0)
        self.up3 = UpBlock(self.ft_chns[2], self.ft_chns[1], self.ft_chns[1], dropout_p=0.0)
        self.up4 = UpBlock(self.ft_chns[1], self.ft_chns[0], self.ft_chns[0], dropout_p=0.0)

        self.out_conv = nn.Conv2d(self.ft_chns[0], self.n_class, kernel_size=3, padding=1)

    def forward(self, feature):
        x0 = feature[0]
        x1 = feature[1]
        x2 = feature[2]
        x3 = feature[3]
        x4 = feature[4]

        x5 = self.up1(x4, x3) # 1, 128, 32, 32
        x6 = self.up2(x5, x2) # 1, 64, 64, 64
        x7 = self.up3(x6, x1) # 1, 32, 128, 128
        x_last = self.up4(x7, x0) # 1, 16, 256, 256
        output = self.out_conv(x_last)
        return output, x_last
    
class UNet(nn.Module):
    def __init__(self, in_chns, class_num):
        super(UNet, self).__init__()

        params = {'in_chns': in_chns,
                  'feature_chns': [16, 32, 64, 128, 256],
                  'dropout': [0.05, 0.1, 0.2, 0.3, 0.5],
                  'class_num': class_num,
                  'acti_func': 'relu'}

        self.encoder = Encoder(params)
        self.decoder = Decoder(params)
        dim_in = 16
        feat_dim = 32
        self.projection_head = nn.Sequential(
            nn.Linear(dim_in, feat_dim),
            nn.BatchNorm1d(feat_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feat_dim, feat_dim)
        )
        self.prediction_head = nn.Sequential(
            nn.Linear(feat_dim, feat_dim),
            nn.BatchNorm1d(feat_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feat_dim, feat_dim)
        )
        for class_c in range(4):
            selector = nn.Sequential(
                nn.Linear(feat_dim, feat_dim),
                nn.BatchNorm1d(feat_dim),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Linear(feat_dim, 1)
            )
            self.__setattr__('contrastive_class_selector_' + str(class_c), selector)

        for class_c in range(4):
            selector = nn.Sequential(
                nn.Linear(feat_dim, feat_dim),
                nn.BatchNorm1d(feat_dim),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Linear(feat_dim, 1)
            )
            self.__setattr__('contrastive_class_selector_memory' + str(class_c), selector)
            
    def forward_projection_head(self, features):
        return self.projection_head(features)

    def forward_prediction_head(self, features):
        return self.prediction_head(features)

    def forward(self, x):
        feature = self.encoder(x)
        output, features = self.decoder(feature)
        return output, features

class UNet_2d(nn.Module):
    def __init__(self, in_chns, class_num):
        super(UNet_2d, self).__init__()

        params = {'in_chns': in_chns,
                  'feature_chns': [16, 32, 64, 128, 256],
                  'dropout': [0.05, 0.1, 0.2, 0.3, 0.5],
                  'class_num': class_num,
                  'acti_func': 'relu'}

        self.encoder = Encoder(params)
        self.decoder = Decoder(params)
        dim_in = 16
        feat_dim = 32
        self.projection_head = nn.Sequential(
            nn.Linear(dim_in, feat_dim),
            nn.BatchNorm1d(feat_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feat_dim, feat_dim)
        )
        self.prediction_head = nn.Sequential(
            nn.Linear(feat_dim, feat_dim),
            nn.BatchNorm1d(feat_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feat_dim, feat_dim)
        )
        for class_c in range(4):
            selector = nn.Sequential(
                nn.Linear(feat_dim, feat_dim),
                nn.BatchNorm1d(feat_dim),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Linear(feat_dim, 1)
            )
            self.__setattr__('contrastive_class_selector_' + str(class_c), selector)

        for class_c in range(4):
            selector = nn.Sequential(
                nn.Linear(feat_dim, feat_dim),
                nn.BatchNorm1d(feat_dim),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Linear(feat_dim, 1)
            )
            self.__setattr__('contrastive_class_selector_memory' + str(class_c), selector)
            
    def forward_projection_head(self, features):
        return self.projection_head(features)
        #return self.decoder(features)

    def forward_prediction_head(self, features):
        return self.prediction_head(features)

    def forward(self, x):
        feature = self.encoder(x)
        output, features = self.decoder(feature)
        return output

In [94]:
import torch.nn as nn

def net_factory(net_type="unet", in_chns=1, class_num=2, mode = "train", tsne=0):
    if net_type == "unet" and mode == "train":
        net = UNet(in_chns=in_chns, class_num=class_num).cuda()
    return net

def BCP_net(in_chns=1, class_num=4, ema=False):
    net = UNet_2d(in_chns=in_chns, class_num=class_num).cuda()
    if ema:
        for param in net.parameters():
            param.detach_()
    return net


#### 4.Training 

In [95]:
def get_ACDC_2DLargestCC(segmentation):
    batch_list = []
    N = segmentation.shape[0]
    for i in range(0, N):
        class_list = []
        for c in range(1, 4):
            temp_seg = segmentation[i] #== c *  torch.ones_like(segmentation[i])
            temp_prob = torch.zeros_like(temp_seg)
            temp_prob[temp_seg == c] = 1
            temp_prob = temp_prob.detach().cpu().numpy()
            labels = label(temp_prob)          
            if labels.max() != 0:
                largestCC = labels == np.argmax(np.bincount(labels.flat)[1:])+1
                class_list.append(largestCC * c)
            else:
                class_list.append(temp_prob)
        
        n_batch = class_list[0] + class_list[1] + class_list[2]
        batch_list.append(n_batch)

    return torch.Tensor(batch_list).cuda()
    
def get_ACDC_masks(output, nms=0):
    probs = F.softmax(output, dim=1)
    _, probs = torch.max(probs, dim=1)
    if nms == 1:
        probs = get_ACDC_2DLargestCC(probs)      
    return probs

In [96]:
def save_net_opt(net, optimizer, path):
    state = {
        'net': net.state_dict(),
        'optim': optimizer.state_dict()
    }
    torch.save(state, str(path))

def load_net(net, path):
    state = torch.load(str(path))
    net.load_state_dict(state['net']) 

def load_net_opt(net, optimizer, path): 
    state = torch.load(str(path))
    net.load_state_dict(state['net'])
    optimizer.load_state_dict(state['optim'])

In [97]:
def generate_mask(img):
    batch_size, channel, img_x, img_y = img.shape[0], img.shape[1], img.shape[2], img.shape[3]
    loss_mask = torch.ones(batch_size, img_x, img_y).cuda()
    mask = torch.ones(img_x, img_y).cuda()
    patch_x, patch_y = int(img_x*2/3), int(img_y*2/3)
    w = np.random.randint(0, img_x - patch_x)
    h = np.random.randint(0, img_y - patch_y)
    mask[w:w+patch_x, h:h+patch_y] = 0
    loss_mask[:, w:w+patch_x, h:h+patch_y] = 0
    return mask.long(), loss_mask.long()

In [98]:
def sigmoid_rampup(current, rampup_length):
    if rampup_length == 0: 
        return 1.0 
    else:
        current = np.clip(current, 0, rampup_length)
        phase = 1 - (current / rampup_length)
        return float(np.exp(-5 * phase * phase))

In [99]:
# Mean-Teacher compomnent 
def get_current_consistency_weight(epoch, args): 
    return 5 * args.consistency + sigmoid_rampup(epoch, args.consistency_rampup)

def update_model_ema(model, ema_model, alpha): 
    model_state = model.state_dict() 
    model_ema_state = ema_model.state_dict()
    new_dict = {}

    for key in model_state:
        new_dict[key] = alpha * model_ema_state[key] + (1 - alpha) * model_state[key]

    ema_model.load_state_dict(new_dict)

In [100]:
# Configuration
def pre_train(args, snapshot_path):
    base_lr = args.base_lr
    num_classes = args.num_classes
    max_iterations = args.pretrain_iterations
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 
    pre_trained_model = os.path.join(snapshot_path, '{}_best_model.pth'.format(args.model))
    labeled_sub_bs, unlabeled_sub_bs = int(args.labeled_bs / 2), int((args.batch_size - args.labeled_bs)/2)

    def worker_init_fn(worker_id):
        random.seed(args.seed + worker_id)

    # load dataset 
    db_train = BaseDataset(root_path = args.root_dir, split= 'train', 
                        num= None, 
                        transform= transforms.Compose([RandomGenerator(args.patch_size)]))
    
    db_val = BaseDataset(root_path= args.root_dir, split= 'val')

    total_slices = len(db_train)
    labeled_slices = patients_to_slices(args.root_dir, args.label_num)
    print(f'Total slice is {total_slices}, Labeled slice is {labeled_slices}')

    # Create batch_sampler 
    labeled_idxs = list(range(0, labeled_slices))
    unlabeled_idxs = list(range(labeled_slices, total_slices))
    batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, args.batch_size, args.batch_size - args.labeled_bs)

    # Create dataloader 
    trainloader = DataLoader(db_train, batch_sampler= batch_sampler, num_workers= 4, pin_memory= True, worker_init_fn= worker_init_fn)
    valloader = DataLoader(db_val, batch_size= 1, shuffle= False, num_workers=1)

    # Define model 
    model = BCP_net(in_chns=1, class_num= num_classes)
    optimizer = optim.SGD(model.parameters(), lr= base_lr, momentum= 0.9, weight_decay= 0.0001)

    writer = SummaryWriter(snapshot_path + '/log')
    logging.info('Start pre-training')
    logging.info(f'{len(trainloader)} iterations per epoch')

    # training process
    model.train() 
    iter_num = 0 
    max_epoch = max_iterations // len(trainloader) + 1 
    best_performance = 0.0 
    best_hd = 100.0
    iterator = tqdm(range(max_epoch), ncols= 70)
    
    for _ in iterator: 
        for _, sampled_batch in enumerate(trainloader): 
            volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda() 

            img_a, img_b = volume_batch[: labeled_sub_bs], volume_batch[labeled_sub_bs : args.labeled_bs]
            lab_a, lab_b = label_batch[: labeled_sub_bs], label_batch[labeled_sub_bs : args.labeled_bs]
            img_mask, loss_mask = generate_mask(img_a)
            gt_mixl = lab_a * img_mask + lab_b * ( 1- img_mask)

            #-- original 
            net_input = img_a * img_mask + img_b * ( 1 - img_mask)  
            out_mixl = model(net_input)
            loss_dice, loss_ce = mix_loss(out_mixl, lab_a, lab_b, loss_mask,u_weight= 1.0, unlab= True )
            loss = (loss_dice + loss_ce )/2 

            optimizer.zero_grad() 
            loss.backward() 
            optimizer.step() 

            iter_num += 1 

            writer.add_scalar('info/total_loss', loss, iter_num)
            writer.add_scalar('info/mix_dice', loss_dice,iter_num )
            writer.add_scalar('info/mix_ce', loss_ce, iter_num)

            logging.info('iteration %d: loss %f, mix_dice: %f, mix_ce: %f'%(iter_num, loss, loss_dice, loss_ce))
            if iter_num % 20 == 0: 
                image = net_input[1, 0:1, :, :]
                writer.add_image('pre_train/Mixed_Image', image, iter_num)
                outputs = torch.argmax(torch.softmax(out_mixl, dim=1), dim=1, keepdim= True)
                writer.add_image('pre_train/Mixed_Prediction', outputs[1, ...]*50, iter_num)
                labs = gt_mixl[1, ...].unsqueeze(0) * 50 
                writer.add_image('pre_train/Mixed_GroundTruth', labs, iter_num)
            
            # Evaluate after 200 epoch ! 
            if iter_num > 0 and iter_num % 200 == 0: 
                model.eval() 
                metric_list = 0.0 
                for _, sampled_batch in enumerate(valloader):
                    metric_i = test_single_volume(sampled_batch['image'], sampled_batch['label'], model, classes= num_classes)
                    metric_list += np.array(metric_i)
                
                metric_list = metric_list / len(db_val)
                
    
                for class_i in range(num_classes - 1 ): 
                    writer.add_scalar('info/val_{}_dice'.format(class_i + 1), metric_list[class_i, 0], iter_num)
                    writer.add_scalar('infor/val_{}_hd'.format(class_i + 1), metric_list[class_i, 1], iter_num)
                
                performance = np.mean(metric_list, axis=0)[0]
                writer.add_scalar('info/val_mean_dice', performance, iter_num)
            
                if performance > best_performance: 
                    best_performance = performance
                    save_model_path = os.path.join(snapshot_path, 'iter_{}_dice_{}.pth'.format(iter_num, round(best_performance,4)))
                    save_best_path = os.path.join(snapshot_path, '{}_best_model.pth'.format(args.model))
                    save_net_opt(model, optimizer, save_model_path)
                    save_net_opt(model, optimizer, save_best_path)
                
                logging.info('iteration %d : mean dice : %f'%(iter_num, performance))
                model.train() 
            
            if iter_num >= max_iterations: 
                break 
        
        if iter_num >= max_iterations: 
            iterator.close() 
            break 
    
    writer.close()

In [101]:
def self_train(args ,pre_snapshot_path, snapshot_path):
    base_lr = args.base_lr
    num_classes = args.num_classes
    max_iterations = args.selftrain_iterations
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    pre_trained_model = os.path.join(pre_snapshot_path,'{}_best_model.pth'.format(args.model))
    labeled_sub_bs, unlabeled_sub_bs = int(args.labeled_bs/2), int((args.batch_size-args.labeled_bs) / 2)
     
    model = BCP_net(in_chns=1, class_num=num_classes)
    ema_model = BCP_net(in_chns=1, class_num=num_classes, ema=True)

    def worker_init_fn(worker_id):
        random.seed(args.seed + worker_id)

    db_train = BaseDataset(root_path=args.root_dir,
                            split="train",
                            num=None,
                            transform=transforms.Compose([RandomGenerator(args.patch_size)]))
    db_val = BaseDataset(root_path=args.root_dir, split="val")
    total_slices = len(db_train)
    labeled_slice = patients_to_slices(args.root_dir,args.label_num)
    print("Total slices is: {}, labeled slices is:{}".format(total_slices, labeled_slice))
    labeled_idxs = list(range(0, labeled_slice))
    unlabeled_idxs = list(range(labeled_slice, total_slices))
    batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, args.batch_size, args.batch_size-args.labeled_bs)

    trainloader = DataLoader(db_train, batch_sampler=batch_sampler, num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn)

    valloader = DataLoader(db_val, batch_size=1, shuffle=False, num_workers=1)

    optimizer = optim.SGD(model.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001)
    load_net(ema_model, pre_trained_model)
    load_net_opt(model, optimizer, pre_trained_model) #
    logging.info("Loaded from {}".format(pre_trained_model))

    writer = SummaryWriter(snapshot_path + '/log')
    logging.info("Start self_training")
    logging.info("{} iterations per epoch".format(len(trainloader)))

    model.train()
    ema_model.train()

    ce_loss = CrossEntropyLoss()

    iter_num = 0
    max_epoch = max_iterations // len(trainloader) + 1
    best_performance = 0.0
    best_hd = 100
    iterator = tqdm(range(max_epoch), ncols=70)
    for _ in iterator:
        for _, sampled_batch in enumerate(trainloader):
            
            volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()

            img_a, img_b = volume_batch[:labeled_sub_bs], volume_batch[labeled_sub_bs:args.labeled_bs]
            uimg_a, uimg_b = volume_batch[args.labeled_bs:args.labeled_bs + unlabeled_sub_bs], volume_batch[args.labeled_bs + unlabeled_sub_bs:]
            ulab_a, ulab_b = label_batch[args.labeled_bs:args.labeled_bs + unlabeled_sub_bs], label_batch[args.labeled_bs + unlabeled_sub_bs:]
            lab_a, lab_b = label_batch[:labeled_sub_bs], label_batch[labeled_sub_bs:args.labeled_bs]
            with torch.no_grad():
                pre_a = ema_model(uimg_a) # pseudo label a 
                pre_b = ema_model(uimg_b) # pseudo label b 
                plab_a = get_ACDC_masks(pre_a, nms=1)
                plab_b = get_ACDC_masks(pre_b, nms=1)
                img_mask, loss_mask = generate_mask(img_a)
                unl_label = ulab_a * img_mask + lab_a * (1 - img_mask)
                l_label = lab_b * img_mask + ulab_b * (1 - img_mask)
            consistency_weight = get_current_consistency_weight(iter_num//150,args)

            net_input_unl = uimg_a * img_mask + img_a * (1 - img_mask)
            net_input_l = img_b * img_mask + uimg_b * (1 - img_mask)
            out_unl = model(net_input_unl)
            out_l = model(net_input_l)
            unl_dice, unl_ce = mix_loss(out_unl, plab_a, lab_a, loss_mask, u_weight=args.u_weight, unlab=True)
            l_dice, l_ce = mix_loss(out_l, lab_b, plab_b, loss_mask, u_weight=args.u_weight)


            loss_ce = unl_ce + l_ce 
            loss_dice = unl_dice + l_dice

            loss = (loss_dice + loss_ce) / 2            

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            iter_num += 1
            update_model_ema(model, ema_model, 0.99)

            writer.add_scalar('info/total_loss', loss, iter_num)
            writer.add_scalar('info/mix_dice', loss_dice, iter_num)
            writer.add_scalar('info/mix_ce', loss_ce, iter_num)
            writer.add_scalar('info/consistency_weight', consistency_weight, iter_num)     

            logging.info('iteration %d: loss: %f, mix_dice: %f, mix_ce: %f'%(iter_num, loss, loss_dice, loss_ce))
                
            if iter_num % 20 == 0:
                image = net_input_unl[1, 0:1, :, :]
                writer.add_image('train/Un_Image', image, iter_num)
                outputs = torch.argmax(torch.softmax(out_unl, dim=1), dim=1, keepdim=True)
                writer.add_image('train/Un_Prediction', outputs[1, ...] * 50, iter_num)
                labs = unl_label[1, ...].unsqueeze(0) * 50
                writer.add_image('train/Un_GroundTruth', labs, iter_num)

                image_l = net_input_l[1, 0:1, :, :]
                writer.add_image('train/L_Image', image_l, iter_num)
                outputs_l = torch.argmax(torch.softmax(out_l, dim=1), dim=1, keepdim=True)
                writer.add_image('train/L_Prediction', outputs_l[1, ...] * 50, iter_num)
                labs_l = l_label[1, ...].unsqueeze(0) * 50
                writer.add_image('train/L_GroundTruth', labs_l, iter_num)

            if iter_num > 0 and iter_num % 200 == 0:
                model.eval()
                metric_list = 0.0
                for _, sampled_batch in enumerate(valloader):
                    metric_i = test_single_volume(sampled_batch["image"], sampled_batch["label"], model, classes= num_classes)
                    metric_list += np.array(metric_i)
                metric_list = metric_list / len(db_val)
                print(f'Metric list: {metric_list}') 
                for class_i in range(num_classes-1):
                    writer.add_scalar('info/val_{}_dice'.format(class_i+1), metric_list[class_i, 0], iter_num)
                    writer.add_scalar('info/val_{}_hd95'.format(class_i+1), metric_list[class_i, 1], iter_num)

                performance = np.mean(metric_list, axis=0)[0]
                writer.add_scalar('info/val_mean_dice', performance, iter_num)

                if performance > best_performance:
                    best_performance = performance
                    save_mode_path = os.path.join(snapshot_path, 'iter_{}_dice_{}.pth'.format(iter_num, round(best_performance, 4)))
                    save_best_path = os.path.join(snapshot_path,'{}_best_model.pth'.format(args.model))
                    torch.save(model.state_dict(), save_mode_path)
                    torch.save(model.state_dict(), save_best_path)

                logging.info('iteration %d : mean_dice : %f' % (iter_num, performance))
                model.train()

            if iter_num >= max_iterations:
                break
        if iter_num >= max_iterations:
            iterator.close()
            break
    writer.close()


In [102]:
# Params 
class params: 
    def __init__(self): 
        self.root_dir = '/kaggle/input/acdc-dataset/ACDC' 
        self.exp = 'BCP' 
        self.model = 'unet' 
        self.pretrain_iterations = 400 
        
        self.selftrain_iterations = 400
        self.batch_size = 24
        self.deterministic = 1 # What the fucck here
        self.base_lr = 0.01 
        self.patch_size = [256,256] 
        self.seed = 42 
        self.num_classes = 4 
        self.stage_name = 'self_train'

        # label and unlabel 
        self.labeled_bs = 12
        self.label_num = 7 
        self.u_weight = 0.5 

        # Cost 
        self.gpu = '0' 
        self.consistency = 0.1
        self.consistency_rampup = 200.0 
        self.magnitude = '6.0' 
        self.s_param = 6 


args = params()
# Trainig process 
if args.deterministic: 
    cudnn.benchmark = False 
    cudnn.deterministic = True 
    random.seed(args.seed) 
    np.random.seed(args.seed) 
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

# Path 
pre_snapshot_path = "./model/BCP/ACDC_{}_{}_labeled/pretrain".format(args.exp, args.label_num)
self_snapshot_path = "./model/BCP/ACDC_{}_{}_labeled/selftrain".format(args.exp, args.label_num)

print(f'Pretrain log path: {pre_snapshot_path + "/log.txt"}')
print(f'Self-train log path: {self_snapshot_path + "/log.txt"}')

for snapshot_path in [pre_snapshot_path, self_snapshot_path]: 
    if not os.path.exists(snapshot_path): 
        os.makedirs(snapshot_path, exist_ok= True)
#Pre_train
logging.basicConfig(filename=pre_snapshot_path+"/log.txt", level=logging.INFO, format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
pre_train(args, pre_snapshot_path)

#Self_train
logging.basicConfig(filename=self_snapshot_path+"/log.txt", level=logging.INFO, format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
self_train(args, pre_snapshot_path, self_snapshot_path)

Pretrain log path: ./model/BCP/ACDC_BCP_7_labeled/pretrain/log.txt
Self-train log path: ./model/BCP/ACDC_BCP_7_labeled/selftrain/log.txt
Total slices: 1312
Total slices: 20
Total slice is 1312, Labeled slice is 136
Start pre-training
Start pre-training
Start pre-training
Start pre-training
Start pre-training
Start pre-training
11 iterations per epoch
11 iterations per epoch
11 iterations per epoch
11 iterations per epoch
11 iterations per epoch
11 iterations per epoch


  0%|                                          | 0/37 [00:00<?, ?it/s]

iteration 1: loss 2.310845, mix_dice: 1.668241, mix_ce: 2.953450
iteration 1: loss 2.310845, mix_dice: 1.668241, mix_ce: 2.953450
iteration 1: loss 2.310845, mix_dice: 1.668241, mix_ce: 2.953450
iteration 1: loss 2.310845, mix_dice: 1.668241, mix_ce: 2.953450
iteration 1: loss 2.310845, mix_dice: 1.668241, mix_ce: 2.953450
iteration 1: loss 2.310845, mix_dice: 1.668241, mix_ce: 2.953450
iteration 2: loss 2.157312, mix_dice: 1.645143, mix_ce: 2.669481
iteration 2: loss 2.157312, mix_dice: 1.645143, mix_ce: 2.669481
iteration 2: loss 2.157312, mix_dice: 1.645143, mix_ce: 2.669481
iteration 2: loss 2.157312, mix_dice: 1.645143, mix_ce: 2.669481
iteration 2: loss 2.157312, mix_dice: 1.645143, mix_ce: 2.669481
iteration 2: loss 2.157312, mix_dice: 1.645143, mix_ce: 2.669481
iteration 3: loss 1.882932, mix_dice: 1.588059, mix_ce: 2.177804
iteration 3: loss 1.882932, mix_dice: 1.588059, mix_ce: 2.177804
iteration 3: loss 1.882932, mix_dice: 1.588059, mix_ce: 2.177804
iteration 3: loss 1.88293

  3%|▉                                 | 1/37 [00:01<00:57,  1.61s/it]

iteration 12: loss 0.918663, mix_dice: 1.384163, mix_ce: 0.453163
iteration 12: loss 0.918663, mix_dice: 1.384163, mix_ce: 0.453163
iteration 12: loss 0.918663, mix_dice: 1.384163, mix_ce: 0.453163
iteration 12: loss 0.918663, mix_dice: 1.384163, mix_ce: 0.453163
iteration 12: loss 0.918663, mix_dice: 1.384163, mix_ce: 0.453163
iteration 12: loss 0.918663, mix_dice: 1.384163, mix_ce: 0.453163
iteration 13: loss 0.876776, mix_dice: 1.340769, mix_ce: 0.412783
iteration 13: loss 0.876776, mix_dice: 1.340769, mix_ce: 0.412783
iteration 13: loss 0.876776, mix_dice: 1.340769, mix_ce: 0.412783
iteration 13: loss 0.876776, mix_dice: 1.340769, mix_ce: 0.412783
iteration 13: loss 0.876776, mix_dice: 1.340769, mix_ce: 0.412783
iteration 13: loss 0.876776, mix_dice: 1.340769, mix_ce: 0.412783
iteration 14: loss 0.945918, mix_dice: 1.409103, mix_ce: 0.482732
iteration 14: loss 0.945918, mix_dice: 1.409103, mix_ce: 0.482732
iteration 14: loss 0.945918, mix_dice: 1.409103, mix_ce: 0.482732
iteration 

  5%|█▊                                | 2/37 [00:03<00:55,  1.58s/it]

iteration 23: loss 0.887729, mix_dice: 1.380793, mix_ce: 0.394666
iteration 23: loss 0.887729, mix_dice: 1.380793, mix_ce: 0.394666
iteration 23: loss 0.887729, mix_dice: 1.380793, mix_ce: 0.394666
iteration 23: loss 0.887729, mix_dice: 1.380793, mix_ce: 0.394666
iteration 23: loss 0.887729, mix_dice: 1.380793, mix_ce: 0.394666
iteration 23: loss 0.887729, mix_dice: 1.380793, mix_ce: 0.394666
iteration 24: loss 0.880357, mix_dice: 1.355361, mix_ce: 0.405353
iteration 24: loss 0.880357, mix_dice: 1.355361, mix_ce: 0.405353
iteration 24: loss 0.880357, mix_dice: 1.355361, mix_ce: 0.405353
iteration 24: loss 0.880357, mix_dice: 1.355361, mix_ce: 0.405353
iteration 24: loss 0.880357, mix_dice: 1.355361, mix_ce: 0.405353
iteration 24: loss 0.880357, mix_dice: 1.355361, mix_ce: 0.405353
iteration 25: loss 1.029328, mix_dice: 1.393960, mix_ce: 0.664696
iteration 25: loss 1.029328, mix_dice: 1.393960, mix_ce: 0.664696
iteration 25: loss 1.029328, mix_dice: 1.393960, mix_ce: 0.664696
iteration 

  8%|██▊                               | 3/37 [00:04<00:54,  1.59s/it]

iteration 34: loss 0.931264, mix_dice: 1.398985, mix_ce: 0.463543
iteration 34: loss 0.931264, mix_dice: 1.398985, mix_ce: 0.463543
iteration 34: loss 0.931264, mix_dice: 1.398985, mix_ce: 0.463543
iteration 34: loss 0.931264, mix_dice: 1.398985, mix_ce: 0.463543
iteration 34: loss 0.931264, mix_dice: 1.398985, mix_ce: 0.463543
iteration 34: loss 0.931264, mix_dice: 1.398985, mix_ce: 0.463543
iteration 35: loss 0.841304, mix_dice: 1.299525, mix_ce: 0.383083
iteration 35: loss 0.841304, mix_dice: 1.299525, mix_ce: 0.383083
iteration 35: loss 0.841304, mix_dice: 1.299525, mix_ce: 0.383083
iteration 35: loss 0.841304, mix_dice: 1.299525, mix_ce: 0.383083
iteration 35: loss 0.841304, mix_dice: 1.299525, mix_ce: 0.383083
iteration 35: loss 0.841304, mix_dice: 1.299525, mix_ce: 0.383083
iteration 36: loss 0.909168, mix_dice: 1.357154, mix_ce: 0.461182
iteration 36: loss 0.909168, mix_dice: 1.357154, mix_ce: 0.461182
iteration 36: loss 0.909168, mix_dice: 1.357154, mix_ce: 0.461182
iteration 

 11%|███▋                              | 4/37 [00:06<00:52,  1.59s/it]

iteration 45: loss 0.789553, mix_dice: 1.280781, mix_ce: 0.298326
iteration 45: loss 0.789553, mix_dice: 1.280781, mix_ce: 0.298326
iteration 45: loss 0.789553, mix_dice: 1.280781, mix_ce: 0.298326
iteration 45: loss 0.789553, mix_dice: 1.280781, mix_ce: 0.298326
iteration 45: loss 0.789553, mix_dice: 1.280781, mix_ce: 0.298326
iteration 45: loss 0.789553, mix_dice: 1.280781, mix_ce: 0.298326
iteration 46: loss 0.930843, mix_dice: 1.406074, mix_ce: 0.455612
iteration 46: loss 0.930843, mix_dice: 1.406074, mix_ce: 0.455612
iteration 46: loss 0.930843, mix_dice: 1.406074, mix_ce: 0.455612
iteration 46: loss 0.930843, mix_dice: 1.406074, mix_ce: 0.455612
iteration 46: loss 0.930843, mix_dice: 1.406074, mix_ce: 0.455612
iteration 46: loss 0.930843, mix_dice: 1.406074, mix_ce: 0.455612
iteration 47: loss 0.856728, mix_dice: 1.320818, mix_ce: 0.392637
iteration 47: loss 0.856728, mix_dice: 1.320818, mix_ce: 0.392637
iteration 47: loss 0.856728, mix_dice: 1.320818, mix_ce: 0.392637
iteration 

 14%|████▌                             | 5/37 [00:08<00:51,  1.62s/it]

iteration 56: loss 0.783423, mix_dice: 1.310668, mix_ce: 0.256179
iteration 56: loss 0.783423, mix_dice: 1.310668, mix_ce: 0.256179
iteration 56: loss 0.783423, mix_dice: 1.310668, mix_ce: 0.256179
iteration 56: loss 0.783423, mix_dice: 1.310668, mix_ce: 0.256179
iteration 56: loss 0.783423, mix_dice: 1.310668, mix_ce: 0.256179
iteration 56: loss 0.783423, mix_dice: 1.310668, mix_ce: 0.256179
iteration 57: loss 0.795948, mix_dice: 1.239373, mix_ce: 0.352523
iteration 57: loss 0.795948, mix_dice: 1.239373, mix_ce: 0.352523
iteration 57: loss 0.795948, mix_dice: 1.239373, mix_ce: 0.352523
iteration 57: loss 0.795948, mix_dice: 1.239373, mix_ce: 0.352523
iteration 57: loss 0.795948, mix_dice: 1.239373, mix_ce: 0.352523
iteration 57: loss 0.795948, mix_dice: 1.239373, mix_ce: 0.352523
iteration 58: loss 0.854817, mix_dice: 1.292881, mix_ce: 0.416753
iteration 58: loss 0.854817, mix_dice: 1.292881, mix_ce: 0.416753
iteration 58: loss 0.854817, mix_dice: 1.292881, mix_ce: 0.416753
iteration 

 16%|█████▌                            | 6/37 [00:09<00:50,  1.62s/it]

iteration 67: loss 0.837385, mix_dice: 1.357588, mix_ce: 0.317183
iteration 67: loss 0.837385, mix_dice: 1.357588, mix_ce: 0.317183
iteration 67: loss 0.837385, mix_dice: 1.357588, mix_ce: 0.317183
iteration 67: loss 0.837385, mix_dice: 1.357588, mix_ce: 0.317183
iteration 67: loss 0.837385, mix_dice: 1.357588, mix_ce: 0.317183
iteration 67: loss 0.837385, mix_dice: 1.357588, mix_ce: 0.317183
iteration 68: loss 0.772601, mix_dice: 1.261007, mix_ce: 0.284195
iteration 68: loss 0.772601, mix_dice: 1.261007, mix_ce: 0.284195
iteration 68: loss 0.772601, mix_dice: 1.261007, mix_ce: 0.284195
iteration 68: loss 0.772601, mix_dice: 1.261007, mix_ce: 0.284195
iteration 68: loss 0.772601, mix_dice: 1.261007, mix_ce: 0.284195
iteration 68: loss 0.772601, mix_dice: 1.261007, mix_ce: 0.284195
iteration 69: loss 0.842725, mix_dice: 1.334349, mix_ce: 0.351101
iteration 69: loss 0.842725, mix_dice: 1.334349, mix_ce: 0.351101
iteration 69: loss 0.842725, mix_dice: 1.334349, mix_ce: 0.351101
iteration 

 19%|██████▍                           | 7/37 [00:11<00:48,  1.61s/it]

iteration 78: loss 0.777076, mix_dice: 1.257473, mix_ce: 0.296679
iteration 78: loss 0.777076, mix_dice: 1.257473, mix_ce: 0.296679
iteration 78: loss 0.777076, mix_dice: 1.257473, mix_ce: 0.296679
iteration 78: loss 0.777076, mix_dice: 1.257473, mix_ce: 0.296679
iteration 78: loss 0.777076, mix_dice: 1.257473, mix_ce: 0.296679
iteration 78: loss 0.777076, mix_dice: 1.257473, mix_ce: 0.296679
iteration 79: loss 0.853315, mix_dice: 1.356631, mix_ce: 0.349998
iteration 79: loss 0.853315, mix_dice: 1.356631, mix_ce: 0.349998
iteration 79: loss 0.853315, mix_dice: 1.356631, mix_ce: 0.349998
iteration 79: loss 0.853315, mix_dice: 1.356631, mix_ce: 0.349998
iteration 79: loss 0.853315, mix_dice: 1.356631, mix_ce: 0.349998
iteration 79: loss 0.853315, mix_dice: 1.356631, mix_ce: 0.349998
iteration 80: loss 0.818762, mix_dice: 1.353292, mix_ce: 0.284231
iteration 80: loss 0.818762, mix_dice: 1.353292, mix_ce: 0.284231
iteration 80: loss 0.818762, mix_dice: 1.353292, mix_ce: 0.284231
iteration 

 22%|███████▎                          | 8/37 [00:12<00:45,  1.58s/it]

iteration 89: loss 0.745341, mix_dice: 1.198321, mix_ce: 0.292361
iteration 89: loss 0.745341, mix_dice: 1.198321, mix_ce: 0.292361
iteration 89: loss 0.745341, mix_dice: 1.198321, mix_ce: 0.292361
iteration 89: loss 0.745341, mix_dice: 1.198321, mix_ce: 0.292361
iteration 89: loss 0.745341, mix_dice: 1.198321, mix_ce: 0.292361
iteration 89: loss 0.745341, mix_dice: 1.198321, mix_ce: 0.292361
iteration 90: loss 0.784154, mix_dice: 1.254346, mix_ce: 0.313961
iteration 90: loss 0.784154, mix_dice: 1.254346, mix_ce: 0.313961
iteration 90: loss 0.784154, mix_dice: 1.254346, mix_ce: 0.313961
iteration 90: loss 0.784154, mix_dice: 1.254346, mix_ce: 0.313961
iteration 90: loss 0.784154, mix_dice: 1.254346, mix_ce: 0.313961
iteration 90: loss 0.784154, mix_dice: 1.254346, mix_ce: 0.313961
iteration 91: loss 0.808774, mix_dice: 1.313331, mix_ce: 0.304216
iteration 91: loss 0.808774, mix_dice: 1.313331, mix_ce: 0.304216
iteration 91: loss 0.808774, mix_dice: 1.313331, mix_ce: 0.304216
iteration 

 24%|████████▎                         | 9/37 [00:14<00:44,  1.58s/it]

iteration 100: loss 0.776036, mix_dice: 1.224990, mix_ce: 0.327082
iteration 100: loss 0.776036, mix_dice: 1.224990, mix_ce: 0.327082
iteration 100: loss 0.776036, mix_dice: 1.224990, mix_ce: 0.327082
iteration 100: loss 0.776036, mix_dice: 1.224990, mix_ce: 0.327082
iteration 100: loss 0.776036, mix_dice: 1.224990, mix_ce: 0.327082
iteration 100: loss 0.776036, mix_dice: 1.224990, mix_ce: 0.327082
iteration 101: loss 0.739965, mix_dice: 1.245547, mix_ce: 0.234383
iteration 101: loss 0.739965, mix_dice: 1.245547, mix_ce: 0.234383
iteration 101: loss 0.739965, mix_dice: 1.245547, mix_ce: 0.234383
iteration 101: loss 0.739965, mix_dice: 1.245547, mix_ce: 0.234383
iteration 101: loss 0.739965, mix_dice: 1.245547, mix_ce: 0.234383
iteration 101: loss 0.739965, mix_dice: 1.245547, mix_ce: 0.234383
iteration 102: loss 0.778821, mix_dice: 1.279531, mix_ce: 0.278111
iteration 102: loss 0.778821, mix_dice: 1.279531, mix_ce: 0.278111
iteration 102: loss 0.778821, mix_dice: 1.279531, mix_ce: 0.27

 27%|████████▉                        | 10/37 [00:15<00:42,  1.58s/it]

iteration 111: loss 0.739956, mix_dice: 1.264417, mix_ce: 0.215495
iteration 111: loss 0.739956, mix_dice: 1.264417, mix_ce: 0.215495
iteration 111: loss 0.739956, mix_dice: 1.264417, mix_ce: 0.215495
iteration 111: loss 0.739956, mix_dice: 1.264417, mix_ce: 0.215495
iteration 111: loss 0.739956, mix_dice: 1.264417, mix_ce: 0.215495
iteration 111: loss 0.739956, mix_dice: 1.264417, mix_ce: 0.215495
iteration 112: loss 0.741110, mix_dice: 1.203413, mix_ce: 0.278807
iteration 112: loss 0.741110, mix_dice: 1.203413, mix_ce: 0.278807
iteration 112: loss 0.741110, mix_dice: 1.203413, mix_ce: 0.278807
iteration 112: loss 0.741110, mix_dice: 1.203413, mix_ce: 0.278807
iteration 112: loss 0.741110, mix_dice: 1.203413, mix_ce: 0.278807
iteration 112: loss 0.741110, mix_dice: 1.203413, mix_ce: 0.278807
iteration 113: loss 0.725736, mix_dice: 1.193627, mix_ce: 0.257845
iteration 113: loss 0.725736, mix_dice: 1.193627, mix_ce: 0.257845
iteration 113: loss 0.725736, mix_dice: 1.193627, mix_ce: 0.25

 30%|█████████▊                       | 11/37 [00:17<00:42,  1.62s/it]

iteration 122: loss 0.700146, mix_dice: 1.153044, mix_ce: 0.247248
iteration 122: loss 0.700146, mix_dice: 1.153044, mix_ce: 0.247248
iteration 122: loss 0.700146, mix_dice: 1.153044, mix_ce: 0.247248
iteration 122: loss 0.700146, mix_dice: 1.153044, mix_ce: 0.247248
iteration 122: loss 0.700146, mix_dice: 1.153044, mix_ce: 0.247248
iteration 122: loss 0.700146, mix_dice: 1.153044, mix_ce: 0.247248
iteration 123: loss 0.772592, mix_dice: 1.236474, mix_ce: 0.308710
iteration 123: loss 0.772592, mix_dice: 1.236474, mix_ce: 0.308710
iteration 123: loss 0.772592, mix_dice: 1.236474, mix_ce: 0.308710
iteration 123: loss 0.772592, mix_dice: 1.236474, mix_ce: 0.308710
iteration 123: loss 0.772592, mix_dice: 1.236474, mix_ce: 0.308710
iteration 123: loss 0.772592, mix_dice: 1.236474, mix_ce: 0.308710
iteration 124: loss 0.769758, mix_dice: 1.243160, mix_ce: 0.296355
iteration 124: loss 0.769758, mix_dice: 1.243160, mix_ce: 0.296355
iteration 124: loss 0.769758, mix_dice: 1.243160, mix_ce: 0.29

 32%|██████████▋                      | 12/37 [00:19<00:39,  1.60s/it]

iteration 133: loss 0.796969, mix_dice: 1.257009, mix_ce: 0.336929
iteration 133: loss 0.796969, mix_dice: 1.257009, mix_ce: 0.336929
iteration 133: loss 0.796969, mix_dice: 1.257009, mix_ce: 0.336929
iteration 133: loss 0.796969, mix_dice: 1.257009, mix_ce: 0.336929
iteration 133: loss 0.796969, mix_dice: 1.257009, mix_ce: 0.336929
iteration 133: loss 0.796969, mix_dice: 1.257009, mix_ce: 0.336929
iteration 134: loss 0.749161, mix_dice: 1.252123, mix_ce: 0.246199
iteration 134: loss 0.749161, mix_dice: 1.252123, mix_ce: 0.246199
iteration 134: loss 0.749161, mix_dice: 1.252123, mix_ce: 0.246199
iteration 134: loss 0.749161, mix_dice: 1.252123, mix_ce: 0.246199
iteration 134: loss 0.749161, mix_dice: 1.252123, mix_ce: 0.246199
iteration 134: loss 0.749161, mix_dice: 1.252123, mix_ce: 0.246199
iteration 135: loss 0.751676, mix_dice: 1.239457, mix_ce: 0.263896
iteration 135: loss 0.751676, mix_dice: 1.239457, mix_ce: 0.263896
iteration 135: loss 0.751676, mix_dice: 1.239457, mix_ce: 0.26

 35%|███████████▌                     | 13/37 [00:20<00:37,  1.57s/it]

iteration 144: loss 0.820679, mix_dice: 1.302188, mix_ce: 0.339171
iteration 144: loss 0.820679, mix_dice: 1.302188, mix_ce: 0.339171
iteration 144: loss 0.820679, mix_dice: 1.302188, mix_ce: 0.339171
iteration 144: loss 0.820679, mix_dice: 1.302188, mix_ce: 0.339171
iteration 144: loss 0.820679, mix_dice: 1.302188, mix_ce: 0.339171
iteration 144: loss 0.820679, mix_dice: 1.302188, mix_ce: 0.339171
iteration 145: loss 0.714524, mix_dice: 1.195220, mix_ce: 0.233828
iteration 145: loss 0.714524, mix_dice: 1.195220, mix_ce: 0.233828
iteration 145: loss 0.714524, mix_dice: 1.195220, mix_ce: 0.233828
iteration 145: loss 0.714524, mix_dice: 1.195220, mix_ce: 0.233828
iteration 145: loss 0.714524, mix_dice: 1.195220, mix_ce: 0.233828
iteration 145: loss 0.714524, mix_dice: 1.195220, mix_ce: 0.233828
iteration 146: loss 0.767065, mix_dice: 1.211486, mix_ce: 0.322643
iteration 146: loss 0.767065, mix_dice: 1.211486, mix_ce: 0.322643
iteration 146: loss 0.767065, mix_dice: 1.211486, mix_ce: 0.32

 38%|████████████▍                    | 14/37 [00:22<00:36,  1.59s/it]

iteration 155: loss 0.714134, mix_dice: 1.207251, mix_ce: 0.221016
iteration 155: loss 0.714134, mix_dice: 1.207251, mix_ce: 0.221016
iteration 155: loss 0.714134, mix_dice: 1.207251, mix_ce: 0.221016
iteration 155: loss 0.714134, mix_dice: 1.207251, mix_ce: 0.221016
iteration 155: loss 0.714134, mix_dice: 1.207251, mix_ce: 0.221016
iteration 155: loss 0.714134, mix_dice: 1.207251, mix_ce: 0.221016
iteration 156: loss 0.699201, mix_dice: 1.174985, mix_ce: 0.223417
iteration 156: loss 0.699201, mix_dice: 1.174985, mix_ce: 0.223417
iteration 156: loss 0.699201, mix_dice: 1.174985, mix_ce: 0.223417
iteration 156: loss 0.699201, mix_dice: 1.174985, mix_ce: 0.223417
iteration 156: loss 0.699201, mix_dice: 1.174985, mix_ce: 0.223417
iteration 156: loss 0.699201, mix_dice: 1.174985, mix_ce: 0.223417
iteration 157: loss 0.821194, mix_dice: 1.311605, mix_ce: 0.330782
iteration 157: loss 0.821194, mix_dice: 1.311605, mix_ce: 0.330782
iteration 157: loss 0.821194, mix_dice: 1.311605, mix_ce: 0.33

 41%|█████████████▍                   | 15/37 [00:24<00:37,  1.69s/it]

iteration 166: loss 0.663101, mix_dice: 1.149075, mix_ce: 0.177126
iteration 166: loss 0.663101, mix_dice: 1.149075, mix_ce: 0.177126
iteration 166: loss 0.663101, mix_dice: 1.149075, mix_ce: 0.177126
iteration 166: loss 0.663101, mix_dice: 1.149075, mix_ce: 0.177126
iteration 166: loss 0.663101, mix_dice: 1.149075, mix_ce: 0.177126
iteration 166: loss 0.663101, mix_dice: 1.149075, mix_ce: 0.177126
iteration 167: loss 0.749003, mix_dice: 1.209995, mix_ce: 0.288011
iteration 167: loss 0.749003, mix_dice: 1.209995, mix_ce: 0.288011
iteration 167: loss 0.749003, mix_dice: 1.209995, mix_ce: 0.288011
iteration 167: loss 0.749003, mix_dice: 1.209995, mix_ce: 0.288011
iteration 167: loss 0.749003, mix_dice: 1.209995, mix_ce: 0.288011
iteration 167: loss 0.749003, mix_dice: 1.209995, mix_ce: 0.288011
iteration 168: loss 0.750760, mix_dice: 1.223988, mix_ce: 0.277532
iteration 168: loss 0.750760, mix_dice: 1.223988, mix_ce: 0.277532
iteration 168: loss 0.750760, mix_dice: 1.223988, mix_ce: 0.27

 43%|██████████████▎                  | 16/37 [00:25<00:34,  1.63s/it]

iteration 177: loss 0.690356, mix_dice: 1.174915, mix_ce: 0.205797
iteration 177: loss 0.690356, mix_dice: 1.174915, mix_ce: 0.205797
iteration 177: loss 0.690356, mix_dice: 1.174915, mix_ce: 0.205797
iteration 177: loss 0.690356, mix_dice: 1.174915, mix_ce: 0.205797
iteration 177: loss 0.690356, mix_dice: 1.174915, mix_ce: 0.205797
iteration 177: loss 0.690356, mix_dice: 1.174915, mix_ce: 0.205797
iteration 178: loss 0.764782, mix_dice: 1.213889, mix_ce: 0.315674
iteration 178: loss 0.764782, mix_dice: 1.213889, mix_ce: 0.315674
iteration 178: loss 0.764782, mix_dice: 1.213889, mix_ce: 0.315674
iteration 178: loss 0.764782, mix_dice: 1.213889, mix_ce: 0.315674
iteration 178: loss 0.764782, mix_dice: 1.213889, mix_ce: 0.315674
iteration 178: loss 0.764782, mix_dice: 1.213889, mix_ce: 0.315674
iteration 179: loss 0.757644, mix_dice: 1.241252, mix_ce: 0.274036
iteration 179: loss 0.757644, mix_dice: 1.241252, mix_ce: 0.274036
iteration 179: loss 0.757644, mix_dice: 1.241252, mix_ce: 0.27

 46%|███████████████▏                 | 17/37 [00:27<00:33,  1.66s/it]

iteration 188: loss 0.743725, mix_dice: 1.217616, mix_ce: 0.269834
iteration 188: loss 0.743725, mix_dice: 1.217616, mix_ce: 0.269834
iteration 188: loss 0.743725, mix_dice: 1.217616, mix_ce: 0.269834
iteration 188: loss 0.743725, mix_dice: 1.217616, mix_ce: 0.269834
iteration 188: loss 0.743725, mix_dice: 1.217616, mix_ce: 0.269834
iteration 188: loss 0.743725, mix_dice: 1.217616, mix_ce: 0.269834
iteration 189: loss 0.776233, mix_dice: 1.258067, mix_ce: 0.294398
iteration 189: loss 0.776233, mix_dice: 1.258067, mix_ce: 0.294398
iteration 189: loss 0.776233, mix_dice: 1.258067, mix_ce: 0.294398
iteration 189: loss 0.776233, mix_dice: 1.258067, mix_ce: 0.294398
iteration 189: loss 0.776233, mix_dice: 1.258067, mix_ce: 0.294398
iteration 189: loss 0.776233, mix_dice: 1.258067, mix_ce: 0.294398
iteration 190: loss 0.601441, mix_dice: 1.026332, mix_ce: 0.176549
iteration 190: loss 0.601441, mix_dice: 1.026332, mix_ce: 0.176549
iteration 190: loss 0.601441, mix_dice: 1.026332, mix_ce: 0.17

 49%|████████████████                 | 18/37 [00:29<00:31,  1.64s/it]

iteration 199: loss 0.764951, mix_dice: 1.216268, mix_ce: 0.313634
iteration 199: loss 0.764951, mix_dice: 1.216268, mix_ce: 0.313634
iteration 199: loss 0.764951, mix_dice: 1.216268, mix_ce: 0.313634
iteration 199: loss 0.764951, mix_dice: 1.216268, mix_ce: 0.313634
iteration 199: loss 0.764951, mix_dice: 1.216268, mix_ce: 0.313634
iteration 199: loss 0.764951, mix_dice: 1.216268, mix_ce: 0.313634
iteration 200: loss 0.616343, mix_dice: 1.057307, mix_ce: 0.175380
iteration 200: loss 0.616343, mix_dice: 1.057307, mix_ce: 0.175380
iteration 200: loss 0.616343, mix_dice: 1.057307, mix_ce: 0.175380
iteration 200: loss 0.616343, mix_dice: 1.057307, mix_ce: 0.175380
iteration 200: loss 0.616343, mix_dice: 1.057307, mix_ce: 0.175380
iteration 200: loss 0.616343, mix_dice: 1.057307, mix_ce: 0.175380
iteration 200 : mean dice : 0.286311
iteration 200 : mean dice : 0.286311
iteration 200 : mean dice : 0.286311
iteration 200 : mean dice : 0.286311
iteration 200 : mean dice : 0.286311
iteration 2

 51%|████████████████▉                | 19/37 [00:39<01:14,  4.15s/it]

iteration 210: loss 0.494499, mix_dice: 0.846162, mix_ce: 0.142836
iteration 210: loss 0.494499, mix_dice: 0.846162, mix_ce: 0.142836
iteration 210: loss 0.494499, mix_dice: 0.846162, mix_ce: 0.142836
iteration 210: loss 0.494499, mix_dice: 0.846162, mix_ce: 0.142836
iteration 210: loss 0.494499, mix_dice: 0.846162, mix_ce: 0.142836
iteration 210: loss 0.494499, mix_dice: 0.846162, mix_ce: 0.142836
iteration 211: loss 0.658382, mix_dice: 1.115425, mix_ce: 0.201339
iteration 211: loss 0.658382, mix_dice: 1.115425, mix_ce: 0.201339
iteration 211: loss 0.658382, mix_dice: 1.115425, mix_ce: 0.201339
iteration 211: loss 0.658382, mix_dice: 1.115425, mix_ce: 0.201339
iteration 211: loss 0.658382, mix_dice: 1.115425, mix_ce: 0.201339
iteration 211: loss 0.658382, mix_dice: 1.115425, mix_ce: 0.201339
iteration 212: loss 0.614678, mix_dice: 1.005532, mix_ce: 0.223825
iteration 212: loss 0.614678, mix_dice: 1.005532, mix_ce: 0.223825
iteration 212: loss 0.614678, mix_dice: 1.005532, mix_ce: 0.22

 54%|█████████████████▊               | 20/37 [00:40<00:57,  3.39s/it]

iteration 221: loss 0.651610, mix_dice: 1.076718, mix_ce: 0.226501
iteration 221: loss 0.651610, mix_dice: 1.076718, mix_ce: 0.226501
iteration 221: loss 0.651610, mix_dice: 1.076718, mix_ce: 0.226501
iteration 221: loss 0.651610, mix_dice: 1.076718, mix_ce: 0.226501
iteration 221: loss 0.651610, mix_dice: 1.076718, mix_ce: 0.226501
iteration 221: loss 0.651610, mix_dice: 1.076718, mix_ce: 0.226501
iteration 222: loss 0.632352, mix_dice: 1.095034, mix_ce: 0.169670
iteration 222: loss 0.632352, mix_dice: 1.095034, mix_ce: 0.169670
iteration 222: loss 0.632352, mix_dice: 1.095034, mix_ce: 0.169670
iteration 222: loss 0.632352, mix_dice: 1.095034, mix_ce: 0.169670
iteration 222: loss 0.632352, mix_dice: 1.095034, mix_ce: 0.169670
iteration 222: loss 0.632352, mix_dice: 1.095034, mix_ce: 0.169670
iteration 223: loss 0.677231, mix_dice: 1.121097, mix_ce: 0.233366
iteration 223: loss 0.677231, mix_dice: 1.121097, mix_ce: 0.233366
iteration 223: loss 0.677231, mix_dice: 1.121097, mix_ce: 0.23

 57%|██████████████████▋              | 21/37 [00:42<00:45,  2.84s/it]

iteration 232: loss 0.731365, mix_dice: 1.208361, mix_ce: 0.254368
iteration 232: loss 0.731365, mix_dice: 1.208361, mix_ce: 0.254368
iteration 232: loss 0.731365, mix_dice: 1.208361, mix_ce: 0.254368
iteration 232: loss 0.731365, mix_dice: 1.208361, mix_ce: 0.254368
iteration 232: loss 0.731365, mix_dice: 1.208361, mix_ce: 0.254368
iteration 232: loss 0.731365, mix_dice: 1.208361, mix_ce: 0.254368
iteration 233: loss 0.668135, mix_dice: 1.131687, mix_ce: 0.204583
iteration 233: loss 0.668135, mix_dice: 1.131687, mix_ce: 0.204583
iteration 233: loss 0.668135, mix_dice: 1.131687, mix_ce: 0.204583
iteration 233: loss 0.668135, mix_dice: 1.131687, mix_ce: 0.204583
iteration 233: loss 0.668135, mix_dice: 1.131687, mix_ce: 0.204583
iteration 233: loss 0.668135, mix_dice: 1.131687, mix_ce: 0.204583
iteration 234: loss 0.677286, mix_dice: 1.168636, mix_ce: 0.185936
iteration 234: loss 0.677286, mix_dice: 1.168636, mix_ce: 0.185936
iteration 234: loss 0.677286, mix_dice: 1.168636, mix_ce: 0.18

 59%|███████████████████▌             | 22/37 [00:43<00:37,  2.48s/it]

iteration 243: loss 0.590346, mix_dice: 0.989833, mix_ce: 0.190860
iteration 243: loss 0.590346, mix_dice: 0.989833, mix_ce: 0.190860
iteration 243: loss 0.590346, mix_dice: 0.989833, mix_ce: 0.190860
iteration 243: loss 0.590346, mix_dice: 0.989833, mix_ce: 0.190860
iteration 243: loss 0.590346, mix_dice: 0.989833, mix_ce: 0.190860
iteration 243: loss 0.590346, mix_dice: 0.989833, mix_ce: 0.190860
iteration 244: loss 0.663403, mix_dice: 1.134505, mix_ce: 0.192302
iteration 244: loss 0.663403, mix_dice: 1.134505, mix_ce: 0.192302
iteration 244: loss 0.663403, mix_dice: 1.134505, mix_ce: 0.192302
iteration 244: loss 0.663403, mix_dice: 1.134505, mix_ce: 0.192302
iteration 244: loss 0.663403, mix_dice: 1.134505, mix_ce: 0.192302
iteration 244: loss 0.663403, mix_dice: 1.134505, mix_ce: 0.192302
iteration 245: loss 0.540797, mix_dice: 0.812622, mix_ce: 0.268972
iteration 245: loss 0.540797, mix_dice: 0.812622, mix_ce: 0.268972
iteration 245: loss 0.540797, mix_dice: 0.812622, mix_ce: 0.26

 62%|████████████████████▌            | 23/37 [00:45<00:30,  2.19s/it]

iteration 254: loss 0.690432, mix_dice: 1.180039, mix_ce: 0.200825
iteration 254: loss 0.690432, mix_dice: 1.180039, mix_ce: 0.200825
iteration 254: loss 0.690432, mix_dice: 1.180039, mix_ce: 0.200825
iteration 254: loss 0.690432, mix_dice: 1.180039, mix_ce: 0.200825
iteration 254: loss 0.690432, mix_dice: 1.180039, mix_ce: 0.200825
iteration 254: loss 0.690432, mix_dice: 1.180039, mix_ce: 0.200825
iteration 255: loss 0.607199, mix_dice: 1.089597, mix_ce: 0.124802
iteration 255: loss 0.607199, mix_dice: 1.089597, mix_ce: 0.124802
iteration 255: loss 0.607199, mix_dice: 1.089597, mix_ce: 0.124802
iteration 255: loss 0.607199, mix_dice: 1.089597, mix_ce: 0.124802
iteration 255: loss 0.607199, mix_dice: 1.089597, mix_ce: 0.124802
iteration 255: loss 0.607199, mix_dice: 1.089597, mix_ce: 0.124802
iteration 256: loss 0.654333, mix_dice: 1.092944, mix_ce: 0.215723
iteration 256: loss 0.654333, mix_dice: 1.092944, mix_ce: 0.215723
iteration 256: loss 0.654333, mix_dice: 1.092944, mix_ce: 0.21

 65%|█████████████████████▍           | 24/37 [00:47<00:26,  2.05s/it]

iteration 265: loss 0.669817, mix_dice: 1.146308, mix_ce: 0.193325
iteration 265: loss 0.669817, mix_dice: 1.146308, mix_ce: 0.193325
iteration 265: loss 0.669817, mix_dice: 1.146308, mix_ce: 0.193325
iteration 265: loss 0.669817, mix_dice: 1.146308, mix_ce: 0.193325
iteration 265: loss 0.669817, mix_dice: 1.146308, mix_ce: 0.193325
iteration 265: loss 0.669817, mix_dice: 1.146308, mix_ce: 0.193325
iteration 266: loss 0.561730, mix_dice: 0.972621, mix_ce: 0.150838
iteration 266: loss 0.561730, mix_dice: 0.972621, mix_ce: 0.150838
iteration 266: loss 0.561730, mix_dice: 0.972621, mix_ce: 0.150838
iteration 266: loss 0.561730, mix_dice: 0.972621, mix_ce: 0.150838
iteration 266: loss 0.561730, mix_dice: 0.972621, mix_ce: 0.150838
iteration 266: loss 0.561730, mix_dice: 0.972621, mix_ce: 0.150838
iteration 267: loss 0.609505, mix_dice: 1.046887, mix_ce: 0.172124
iteration 267: loss 0.609505, mix_dice: 1.046887, mix_ce: 0.172124
iteration 267: loss 0.609505, mix_dice: 1.046887, mix_ce: 0.17

 68%|██████████████████████▎          | 25/37 [00:48<00:23,  1.92s/it]

iteration 276: loss 0.621224, mix_dice: 1.055420, mix_ce: 0.187028
iteration 276: loss 0.621224, mix_dice: 1.055420, mix_ce: 0.187028
iteration 276: loss 0.621224, mix_dice: 1.055420, mix_ce: 0.187028
iteration 276: loss 0.621224, mix_dice: 1.055420, mix_ce: 0.187028
iteration 276: loss 0.621224, mix_dice: 1.055420, mix_ce: 0.187028
iteration 276: loss 0.621224, mix_dice: 1.055420, mix_ce: 0.187028
iteration 277: loss 0.791406, mix_dice: 1.234500, mix_ce: 0.348313
iteration 277: loss 0.791406, mix_dice: 1.234500, mix_ce: 0.348313
iteration 277: loss 0.791406, mix_dice: 1.234500, mix_ce: 0.348313
iteration 277: loss 0.791406, mix_dice: 1.234500, mix_ce: 0.348313
iteration 277: loss 0.791406, mix_dice: 1.234500, mix_ce: 0.348313
iteration 277: loss 0.791406, mix_dice: 1.234500, mix_ce: 0.348313
iteration 278: loss 0.621651, mix_dice: 1.038096, mix_ce: 0.205207
iteration 278: loss 0.621651, mix_dice: 1.038096, mix_ce: 0.205207
iteration 278: loss 0.621651, mix_dice: 1.038096, mix_ce: 0.20

 70%|███████████████████████▏         | 26/37 [00:50<00:20,  1.84s/it]

iteration 287: loss 0.612391, mix_dice: 1.078508, mix_ce: 0.146274
iteration 287: loss 0.612391, mix_dice: 1.078508, mix_ce: 0.146274
iteration 287: loss 0.612391, mix_dice: 1.078508, mix_ce: 0.146274
iteration 287: loss 0.612391, mix_dice: 1.078508, mix_ce: 0.146274
iteration 287: loss 0.612391, mix_dice: 1.078508, mix_ce: 0.146274
iteration 287: loss 0.612391, mix_dice: 1.078508, mix_ce: 0.146274
iteration 288: loss 0.566632, mix_dice: 0.985644, mix_ce: 0.147620
iteration 288: loss 0.566632, mix_dice: 0.985644, mix_ce: 0.147620
iteration 288: loss 0.566632, mix_dice: 0.985644, mix_ce: 0.147620
iteration 288: loss 0.566632, mix_dice: 0.985644, mix_ce: 0.147620
iteration 288: loss 0.566632, mix_dice: 0.985644, mix_ce: 0.147620
iteration 288: loss 0.566632, mix_dice: 0.985644, mix_ce: 0.147620
iteration 289: loss 0.645944, mix_dice: 1.089104, mix_ce: 0.202784
iteration 289: loss 0.645944, mix_dice: 1.089104, mix_ce: 0.202784
iteration 289: loss 0.645944, mix_dice: 1.089104, mix_ce: 0.20

 73%|████████████████████████         | 27/37 [00:51<00:17,  1.77s/it]

iteration 298: loss 0.740294, mix_dice: 1.252752, mix_ce: 0.227835
iteration 298: loss 0.740294, mix_dice: 1.252752, mix_ce: 0.227835
iteration 298: loss 0.740294, mix_dice: 1.252752, mix_ce: 0.227835
iteration 298: loss 0.740294, mix_dice: 1.252752, mix_ce: 0.227835
iteration 298: loss 0.740294, mix_dice: 1.252752, mix_ce: 0.227835
iteration 298: loss 0.740294, mix_dice: 1.252752, mix_ce: 0.227835
iteration 299: loss 0.585642, mix_dice: 0.990110, mix_ce: 0.181175
iteration 299: loss 0.585642, mix_dice: 0.990110, mix_ce: 0.181175
iteration 299: loss 0.585642, mix_dice: 0.990110, mix_ce: 0.181175
iteration 299: loss 0.585642, mix_dice: 0.990110, mix_ce: 0.181175
iteration 299: loss 0.585642, mix_dice: 0.990110, mix_ce: 0.181175
iteration 299: loss 0.585642, mix_dice: 0.990110, mix_ce: 0.181175
iteration 300: loss 0.640323, mix_dice: 1.106883, mix_ce: 0.173763
iteration 300: loss 0.640323, mix_dice: 1.106883, mix_ce: 0.173763
iteration 300: loss 0.640323, mix_dice: 1.106883, mix_ce: 0.17

 76%|████████████████████████▉        | 28/37 [00:53<00:15,  1.72s/it]

iteration 309: loss 0.646084, mix_dice: 1.092225, mix_ce: 0.199943
iteration 309: loss 0.646084, mix_dice: 1.092225, mix_ce: 0.199943
iteration 309: loss 0.646084, mix_dice: 1.092225, mix_ce: 0.199943
iteration 309: loss 0.646084, mix_dice: 1.092225, mix_ce: 0.199943
iteration 309: loss 0.646084, mix_dice: 1.092225, mix_ce: 0.199943
iteration 309: loss 0.646084, mix_dice: 1.092225, mix_ce: 0.199943
iteration 310: loss 0.523852, mix_dice: 0.909975, mix_ce: 0.137729
iteration 310: loss 0.523852, mix_dice: 0.909975, mix_ce: 0.137729
iteration 310: loss 0.523852, mix_dice: 0.909975, mix_ce: 0.137729
iteration 310: loss 0.523852, mix_dice: 0.909975, mix_ce: 0.137729
iteration 310: loss 0.523852, mix_dice: 0.909975, mix_ce: 0.137729
iteration 310: loss 0.523852, mix_dice: 0.909975, mix_ce: 0.137729
iteration 311: loss 0.650752, mix_dice: 1.092210, mix_ce: 0.209295
iteration 311: loss 0.650752, mix_dice: 1.092210, mix_ce: 0.209295
iteration 311: loss 0.650752, mix_dice: 1.092210, mix_ce: 0.20

 78%|█████████████████████████▊       | 29/37 [00:55<00:13,  1.75s/it]

iteration 320: loss 0.545418, mix_dice: 0.942664, mix_ce: 0.148172
iteration 320: loss 0.545418, mix_dice: 0.942664, mix_ce: 0.148172
iteration 320: loss 0.545418, mix_dice: 0.942664, mix_ce: 0.148172
iteration 320: loss 0.545418, mix_dice: 0.942664, mix_ce: 0.148172
iteration 320: loss 0.545418, mix_dice: 0.942664, mix_ce: 0.148172
iteration 320: loss 0.545418, mix_dice: 0.942664, mix_ce: 0.148172
iteration 321: loss 0.541028, mix_dice: 0.971243, mix_ce: 0.110813
iteration 321: loss 0.541028, mix_dice: 0.971243, mix_ce: 0.110813
iteration 321: loss 0.541028, mix_dice: 0.971243, mix_ce: 0.110813
iteration 321: loss 0.541028, mix_dice: 0.971243, mix_ce: 0.110813
iteration 321: loss 0.541028, mix_dice: 0.971243, mix_ce: 0.110813
iteration 321: loss 0.541028, mix_dice: 0.971243, mix_ce: 0.110813
iteration 322: loss 0.566131, mix_dice: 0.962269, mix_ce: 0.169993
iteration 322: loss 0.566131, mix_dice: 0.962269, mix_ce: 0.169993
iteration 322: loss 0.566131, mix_dice: 0.962269, mix_ce: 0.16

 81%|██████████████████████████▊      | 30/37 [00:57<00:12,  1.75s/it]

iteration 331: loss 0.586371, mix_dice: 1.026137, mix_ce: 0.146604
iteration 331: loss 0.586371, mix_dice: 1.026137, mix_ce: 0.146604
iteration 331: loss 0.586371, mix_dice: 1.026137, mix_ce: 0.146604
iteration 331: loss 0.586371, mix_dice: 1.026137, mix_ce: 0.146604
iteration 331: loss 0.586371, mix_dice: 1.026137, mix_ce: 0.146604
iteration 331: loss 0.586371, mix_dice: 1.026137, mix_ce: 0.146604
iteration 332: loss 0.550141, mix_dice: 0.923402, mix_ce: 0.176880
iteration 332: loss 0.550141, mix_dice: 0.923402, mix_ce: 0.176880
iteration 332: loss 0.550141, mix_dice: 0.923402, mix_ce: 0.176880
iteration 332: loss 0.550141, mix_dice: 0.923402, mix_ce: 0.176880
iteration 332: loss 0.550141, mix_dice: 0.923402, mix_ce: 0.176880
iteration 332: loss 0.550141, mix_dice: 0.923402, mix_ce: 0.176880
iteration 333: loss 0.522708, mix_dice: 0.916389, mix_ce: 0.129028
iteration 333: loss 0.522708, mix_dice: 0.916389, mix_ce: 0.129028
iteration 333: loss 0.522708, mix_dice: 0.916389, mix_ce: 0.12

 84%|███████████████████████████▋     | 31/37 [00:58<00:10,  1.71s/it]

iteration 342: loss 0.561752, mix_dice: 0.957143, mix_ce: 0.166361
iteration 342: loss 0.561752, mix_dice: 0.957143, mix_ce: 0.166361
iteration 342: loss 0.561752, mix_dice: 0.957143, mix_ce: 0.166361
iteration 342: loss 0.561752, mix_dice: 0.957143, mix_ce: 0.166361
iteration 342: loss 0.561752, mix_dice: 0.957143, mix_ce: 0.166361
iteration 342: loss 0.561752, mix_dice: 0.957143, mix_ce: 0.166361
iteration 343: loss 0.595277, mix_dice: 1.000311, mix_ce: 0.190243
iteration 343: loss 0.595277, mix_dice: 1.000311, mix_ce: 0.190243
iteration 343: loss 0.595277, mix_dice: 1.000311, mix_ce: 0.190243
iteration 343: loss 0.595277, mix_dice: 1.000311, mix_ce: 0.190243
iteration 343: loss 0.595277, mix_dice: 1.000311, mix_ce: 0.190243
iteration 343: loss 0.595277, mix_dice: 1.000311, mix_ce: 0.190243
iteration 344: loss 0.498699, mix_dice: 0.840222, mix_ce: 0.157177
iteration 344: loss 0.498699, mix_dice: 0.840222, mix_ce: 0.157177
iteration 344: loss 0.498699, mix_dice: 0.840222, mix_ce: 0.15

 86%|████████████████████████████▌    | 32/37 [01:00<00:08,  1.67s/it]

iteration 353: loss 0.558006, mix_dice: 0.974400, mix_ce: 0.141611
iteration 353: loss 0.558006, mix_dice: 0.974400, mix_ce: 0.141611
iteration 353: loss 0.558006, mix_dice: 0.974400, mix_ce: 0.141611
iteration 353: loss 0.558006, mix_dice: 0.974400, mix_ce: 0.141611
iteration 353: loss 0.558006, mix_dice: 0.974400, mix_ce: 0.141611
iteration 353: loss 0.558006, mix_dice: 0.974400, mix_ce: 0.141611
iteration 354: loss 0.522005, mix_dice: 0.925140, mix_ce: 0.118870
iteration 354: loss 0.522005, mix_dice: 0.925140, mix_ce: 0.118870
iteration 354: loss 0.522005, mix_dice: 0.925140, mix_ce: 0.118870
iteration 354: loss 0.522005, mix_dice: 0.925140, mix_ce: 0.118870
iteration 354: loss 0.522005, mix_dice: 0.925140, mix_ce: 0.118870
iteration 354: loss 0.522005, mix_dice: 0.925140, mix_ce: 0.118870
iteration 355: loss 0.537837, mix_dice: 0.928367, mix_ce: 0.147307
iteration 355: loss 0.537837, mix_dice: 0.928367, mix_ce: 0.147307
iteration 355: loss 0.537837, mix_dice: 0.928367, mix_ce: 0.14

 89%|█████████████████████████████▍   | 33/37 [01:01<00:06,  1.64s/it]

iteration 364: loss 0.636106, mix_dice: 1.092079, mix_ce: 0.180133
iteration 364: loss 0.636106, mix_dice: 1.092079, mix_ce: 0.180133
iteration 364: loss 0.636106, mix_dice: 1.092079, mix_ce: 0.180133
iteration 364: loss 0.636106, mix_dice: 1.092079, mix_ce: 0.180133
iteration 364: loss 0.636106, mix_dice: 1.092079, mix_ce: 0.180133
iteration 364: loss 0.636106, mix_dice: 1.092079, mix_ce: 0.180133
iteration 365: loss 0.487432, mix_dice: 0.851882, mix_ce: 0.122983
iteration 365: loss 0.487432, mix_dice: 0.851882, mix_ce: 0.122983
iteration 365: loss 0.487432, mix_dice: 0.851882, mix_ce: 0.122983
iteration 365: loss 0.487432, mix_dice: 0.851882, mix_ce: 0.122983
iteration 365: loss 0.487432, mix_dice: 0.851882, mix_ce: 0.122983
iteration 365: loss 0.487432, mix_dice: 0.851882, mix_ce: 0.122983
iteration 366: loss 0.463026, mix_dice: 0.772462, mix_ce: 0.153591
iteration 366: loss 0.463026, mix_dice: 0.772462, mix_ce: 0.153591
iteration 366: loss 0.463026, mix_dice: 0.772462, mix_ce: 0.15

 92%|██████████████████████████████▎  | 34/37 [01:03<00:04,  1.64s/it]

iteration 375: loss 0.648215, mix_dice: 1.097335, mix_ce: 0.199096
iteration 375: loss 0.648215, mix_dice: 1.097335, mix_ce: 0.199096
iteration 375: loss 0.648215, mix_dice: 1.097335, mix_ce: 0.199096
iteration 375: loss 0.648215, mix_dice: 1.097335, mix_ce: 0.199096
iteration 375: loss 0.648215, mix_dice: 1.097335, mix_ce: 0.199096
iteration 375: loss 0.648215, mix_dice: 1.097335, mix_ce: 0.199096
iteration 376: loss 0.615355, mix_dice: 1.044820, mix_ce: 0.185890
iteration 376: loss 0.615355, mix_dice: 1.044820, mix_ce: 0.185890
iteration 376: loss 0.615355, mix_dice: 1.044820, mix_ce: 0.185890
iteration 376: loss 0.615355, mix_dice: 1.044820, mix_ce: 0.185890
iteration 376: loss 0.615355, mix_dice: 1.044820, mix_ce: 0.185890
iteration 376: loss 0.615355, mix_dice: 1.044820, mix_ce: 0.185890
iteration 377: loss 0.537747, mix_dice: 0.868156, mix_ce: 0.207338
iteration 377: loss 0.537747, mix_dice: 0.868156, mix_ce: 0.207338
iteration 377: loss 0.537747, mix_dice: 0.868156, mix_ce: 0.20

 95%|███████████████████████████████▏ | 35/37 [01:05<00:03,  1.63s/it]

iteration 386: loss 0.414998, mix_dice: 0.705980, mix_ce: 0.124015
iteration 386: loss 0.414998, mix_dice: 0.705980, mix_ce: 0.124015
iteration 386: loss 0.414998, mix_dice: 0.705980, mix_ce: 0.124015
iteration 386: loss 0.414998, mix_dice: 0.705980, mix_ce: 0.124015
iteration 386: loss 0.414998, mix_dice: 0.705980, mix_ce: 0.124015
iteration 386: loss 0.414998, mix_dice: 0.705980, mix_ce: 0.124015
iteration 387: loss 0.471897, mix_dice: 0.761034, mix_ce: 0.182760
iteration 387: loss 0.471897, mix_dice: 0.761034, mix_ce: 0.182760
iteration 387: loss 0.471897, mix_dice: 0.761034, mix_ce: 0.182760
iteration 387: loss 0.471897, mix_dice: 0.761034, mix_ce: 0.182760
iteration 387: loss 0.471897, mix_dice: 0.761034, mix_ce: 0.182760
iteration 387: loss 0.471897, mix_dice: 0.761034, mix_ce: 0.182760
iteration 388: loss 0.597243, mix_dice: 1.015693, mix_ce: 0.178793
iteration 388: loss 0.597243, mix_dice: 1.015693, mix_ce: 0.178793
iteration 388: loss 0.597243, mix_dice: 1.015693, mix_ce: 0.17

 97%|████████████████████████████████ | 36/37 [01:06<00:01,  1.61s/it]

iteration 397: loss 0.561666, mix_dice: 1.013662, mix_ce: 0.109669
iteration 397: loss 0.561666, mix_dice: 1.013662, mix_ce: 0.109669
iteration 397: loss 0.561666, mix_dice: 1.013662, mix_ce: 0.109669
iteration 397: loss 0.561666, mix_dice: 1.013662, mix_ce: 0.109669
iteration 397: loss 0.561666, mix_dice: 1.013662, mix_ce: 0.109669
iteration 397: loss 0.561666, mix_dice: 1.013662, mix_ce: 0.109669
iteration 398: loss 0.615856, mix_dice: 1.025691, mix_ce: 0.206021
iteration 398: loss 0.615856, mix_dice: 1.025691, mix_ce: 0.206021
iteration 398: loss 0.615856, mix_dice: 1.025691, mix_ce: 0.206021
iteration 398: loss 0.615856, mix_dice: 1.025691, mix_ce: 0.206021
iteration 398: loss 0.615856, mix_dice: 1.025691, mix_ce: 0.206021
iteration 398: loss 0.615856, mix_dice: 1.025691, mix_ce: 0.206021
iteration 399: loss 0.564083, mix_dice: 0.992507, mix_ce: 0.135658
iteration 399: loss 0.564083, mix_dice: 0.992507, mix_ce: 0.135658
iteration 399: loss 0.564083, mix_dice: 0.992507, mix_ce: 0.13

 97%|████████████████████████████████ | 36/37 [01:15<00:02,  2.11s/it]

Total slices: 1312
Total slices: 20
Total slices is: 1312, labeled slices is:136
Loaded from ./model/BCP/ACDC_BCP_7_labeled/pretrain/unet_best_model.pth
Loaded from ./model/BCP/ACDC_BCP_7_labeled/pretrain/unet_best_model.pth
Loaded from ./model/BCP/ACDC_BCP_7_labeled/pretrain/unet_best_model.pth
Loaded from ./model/BCP/ACDC_BCP_7_labeled/pretrain/unet_best_model.pth
Loaded from ./model/BCP/ACDC_BCP_7_labeled/pretrain/unet_best_model.pth
Loaded from ./model/BCP/ACDC_BCP_7_labeled/pretrain/unet_best_model.pth
Loaded from ./model/BCP/ACDC_BCP_7_labeled/pretrain/unet_best_model.pth
Start self_training
Start self_training
Start self_training
Start self_training
Start self_training
Start self_training
Start self_training
11 iterations per epoch
11 iterations per epoch
11 iterations per epoch
11 iterations per epoch
11 iterations per epoch
11 iterations per epoch
11 iterations per epoch



  state = torch.load(str(path))
  state = torch.load(str(path))
  0%|                                          | 0/37 [00:00<?, ?it/s]

iteration 1: loss: 0.705734, mix_dice: 1.173441, mix_ce: 0.238027
iteration 1: loss: 0.705734, mix_dice: 1.173441, mix_ce: 0.238027
iteration 1: loss: 0.705734, mix_dice: 1.173441, mix_ce: 0.238027
iteration 1: loss: 0.705734, mix_dice: 1.173441, mix_ce: 0.238027
iteration 1: loss: 0.705734, mix_dice: 1.173441, mix_ce: 0.238027
iteration 1: loss: 0.705734, mix_dice: 1.173441, mix_ce: 0.238027
iteration 1: loss: 0.705734, mix_dice: 1.173441, mix_ce: 0.238027
iteration 2: loss: 0.813207, mix_dice: 1.374241, mix_ce: 0.252173
iteration 2: loss: 0.813207, mix_dice: 1.374241, mix_ce: 0.252173
iteration 2: loss: 0.813207, mix_dice: 1.374241, mix_ce: 0.252173
iteration 2: loss: 0.813207, mix_dice: 1.374241, mix_ce: 0.252173
iteration 2: loss: 0.813207, mix_dice: 1.374241, mix_ce: 0.252173
iteration 2: loss: 0.813207, mix_dice: 1.374241, mix_ce: 0.252173
iteration 2: loss: 0.813207, mix_dice: 1.374241, mix_ce: 0.252173
iteration 3: loss: 0.800984, mix_dice: 1.370526, mix_ce: 0.231441
iteration 

  3%|▉                                 | 1/37 [00:04<02:46,  4.62s/it]

iteration 12: loss: 0.804342, mix_dice: 1.435995, mix_ce: 0.172689
iteration 12: loss: 0.804342, mix_dice: 1.435995, mix_ce: 0.172689
iteration 12: loss: 0.804342, mix_dice: 1.435995, mix_ce: 0.172689
iteration 12: loss: 0.804342, mix_dice: 1.435995, mix_ce: 0.172689
iteration 12: loss: 0.804342, mix_dice: 1.435995, mix_ce: 0.172689
iteration 12: loss: 0.804342, mix_dice: 1.435995, mix_ce: 0.172689
iteration 12: loss: 0.804342, mix_dice: 1.435995, mix_ce: 0.172689
iteration 13: loss: 0.843415, mix_dice: 1.458537, mix_ce: 0.228292
iteration 13: loss: 0.843415, mix_dice: 1.458537, mix_ce: 0.228292
iteration 13: loss: 0.843415, mix_dice: 1.458537, mix_ce: 0.228292
iteration 13: loss: 0.843415, mix_dice: 1.458537, mix_ce: 0.228292
iteration 13: loss: 0.843415, mix_dice: 1.458537, mix_ce: 0.228292
iteration 13: loss: 0.843415, mix_dice: 1.458537, mix_ce: 0.228292
iteration 13: loss: 0.843415, mix_dice: 1.458537, mix_ce: 0.228292
iteration 14: loss: 0.775965, mix_dice: 1.377857, mix_ce: 0.17

  5%|█▊                                | 2/37 [00:09<02:39,  4.56s/it]

iteration 23: loss: 0.752032, mix_dice: 1.317869, mix_ce: 0.186196
iteration 23: loss: 0.752032, mix_dice: 1.317869, mix_ce: 0.186196
iteration 23: loss: 0.752032, mix_dice: 1.317869, mix_ce: 0.186196
iteration 23: loss: 0.752032, mix_dice: 1.317869, mix_ce: 0.186196
iteration 23: loss: 0.752032, mix_dice: 1.317869, mix_ce: 0.186196
iteration 23: loss: 0.752032, mix_dice: 1.317869, mix_ce: 0.186196
iteration 23: loss: 0.752032, mix_dice: 1.317869, mix_ce: 0.186196
iteration 24: loss: 0.828614, mix_dice: 1.393964, mix_ce: 0.263264
iteration 24: loss: 0.828614, mix_dice: 1.393964, mix_ce: 0.263264
iteration 24: loss: 0.828614, mix_dice: 1.393964, mix_ce: 0.263264
iteration 24: loss: 0.828614, mix_dice: 1.393964, mix_ce: 0.263264
iteration 24: loss: 0.828614, mix_dice: 1.393964, mix_ce: 0.263264
iteration 24: loss: 0.828614, mix_dice: 1.393964, mix_ce: 0.263264
iteration 24: loss: 0.828614, mix_dice: 1.393964, mix_ce: 0.263264
iteration 25: loss: 0.915697, mix_dice: 1.516899, mix_ce: 0.31

  8%|██▊                               | 3/37 [00:13<02:35,  4.57s/it]

iteration 34: loss: 0.730091, mix_dice: 1.252662, mix_ce: 0.207520
iteration 34: loss: 0.730091, mix_dice: 1.252662, mix_ce: 0.207520
iteration 34: loss: 0.730091, mix_dice: 1.252662, mix_ce: 0.207520
iteration 34: loss: 0.730091, mix_dice: 1.252662, mix_ce: 0.207520
iteration 34: loss: 0.730091, mix_dice: 1.252662, mix_ce: 0.207520
iteration 34: loss: 0.730091, mix_dice: 1.252662, mix_ce: 0.207520
iteration 34: loss: 0.730091, mix_dice: 1.252662, mix_ce: 0.207520
iteration 35: loss: 0.835542, mix_dice: 1.439630, mix_ce: 0.231454
iteration 35: loss: 0.835542, mix_dice: 1.439630, mix_ce: 0.231454
iteration 35: loss: 0.835542, mix_dice: 1.439630, mix_ce: 0.231454
iteration 35: loss: 0.835542, mix_dice: 1.439630, mix_ce: 0.231454
iteration 35: loss: 0.835542, mix_dice: 1.439630, mix_ce: 0.231454
iteration 35: loss: 0.835542, mix_dice: 1.439630, mix_ce: 0.231454
iteration 35: loss: 0.835542, mix_dice: 1.439630, mix_ce: 0.231454
iteration 36: loss: 0.785418, mix_dice: 1.392186, mix_ce: 0.17

 11%|███▋                              | 4/37 [00:18<02:29,  4.52s/it]

iteration 45: loss: 0.782068, mix_dice: 1.281458, mix_ce: 0.282677
iteration 45: loss: 0.782068, mix_dice: 1.281458, mix_ce: 0.282677
iteration 45: loss: 0.782068, mix_dice: 1.281458, mix_ce: 0.282677
iteration 45: loss: 0.782068, mix_dice: 1.281458, mix_ce: 0.282677
iteration 45: loss: 0.782068, mix_dice: 1.281458, mix_ce: 0.282677
iteration 45: loss: 0.782068, mix_dice: 1.281458, mix_ce: 0.282677
iteration 45: loss: 0.782068, mix_dice: 1.281458, mix_ce: 0.282677
iteration 46: loss: 0.494175, mix_dice: 0.802492, mix_ce: 0.185857
iteration 46: loss: 0.494175, mix_dice: 0.802492, mix_ce: 0.185857
iteration 46: loss: 0.494175, mix_dice: 0.802492, mix_ce: 0.185857
iteration 46: loss: 0.494175, mix_dice: 0.802492, mix_ce: 0.185857
iteration 46: loss: 0.494175, mix_dice: 0.802492, mix_ce: 0.185857
iteration 46: loss: 0.494175, mix_dice: 0.802492, mix_ce: 0.185857
iteration 46: loss: 0.494175, mix_dice: 0.802492, mix_ce: 0.185857
iteration 47: loss: 0.728657, mix_dice: 1.281264, mix_ce: 0.17

 14%|████▌                             | 5/37 [00:22<02:24,  4.51s/it]

iteration 56: loss: 0.720789, mix_dice: 1.212069, mix_ce: 0.229510
iteration 56: loss: 0.720789, mix_dice: 1.212069, mix_ce: 0.229510
iteration 56: loss: 0.720789, mix_dice: 1.212069, mix_ce: 0.229510
iteration 56: loss: 0.720789, mix_dice: 1.212069, mix_ce: 0.229510
iteration 56: loss: 0.720789, mix_dice: 1.212069, mix_ce: 0.229510
iteration 56: loss: 0.720789, mix_dice: 1.212069, mix_ce: 0.229510
iteration 56: loss: 0.720789, mix_dice: 1.212069, mix_ce: 0.229510
iteration 57: loss: 0.784072, mix_dice: 1.438459, mix_ce: 0.129686
iteration 57: loss: 0.784072, mix_dice: 1.438459, mix_ce: 0.129686
iteration 57: loss: 0.784072, mix_dice: 1.438459, mix_ce: 0.129686
iteration 57: loss: 0.784072, mix_dice: 1.438459, mix_ce: 0.129686
iteration 57: loss: 0.784072, mix_dice: 1.438459, mix_ce: 0.129686
iteration 57: loss: 0.784072, mix_dice: 1.438459, mix_ce: 0.129686
iteration 57: loss: 0.784072, mix_dice: 1.438459, mix_ce: 0.129686
iteration 58: loss: 0.719660, mix_dice: 1.316261, mix_ce: 0.12

 16%|█████▌                            | 6/37 [00:27<02:20,  4.53s/it]

iteration 67: loss: 0.752443, mix_dice: 1.325784, mix_ce: 0.179103
iteration 67: loss: 0.752443, mix_dice: 1.325784, mix_ce: 0.179103
iteration 67: loss: 0.752443, mix_dice: 1.325784, mix_ce: 0.179103
iteration 67: loss: 0.752443, mix_dice: 1.325784, mix_ce: 0.179103
iteration 67: loss: 0.752443, mix_dice: 1.325784, mix_ce: 0.179103
iteration 67: loss: 0.752443, mix_dice: 1.325784, mix_ce: 0.179103
iteration 67: loss: 0.752443, mix_dice: 1.325784, mix_ce: 0.179103
iteration 68: loss: 0.721144, mix_dice: 1.225314, mix_ce: 0.216975
iteration 68: loss: 0.721144, mix_dice: 1.225314, mix_ce: 0.216975
iteration 68: loss: 0.721144, mix_dice: 1.225314, mix_ce: 0.216975
iteration 68: loss: 0.721144, mix_dice: 1.225314, mix_ce: 0.216975
iteration 68: loss: 0.721144, mix_dice: 1.225314, mix_ce: 0.216975
iteration 68: loss: 0.721144, mix_dice: 1.225314, mix_ce: 0.216975
iteration 68: loss: 0.721144, mix_dice: 1.225314, mix_ce: 0.216975
iteration 69: loss: 0.773512, mix_dice: 1.316700, mix_ce: 0.23

 19%|██████▍                           | 7/37 [00:31<02:15,  4.50s/it]

iteration 78: loss: 0.755227, mix_dice: 1.365022, mix_ce: 0.145432
iteration 78: loss: 0.755227, mix_dice: 1.365022, mix_ce: 0.145432
iteration 78: loss: 0.755227, mix_dice: 1.365022, mix_ce: 0.145432
iteration 78: loss: 0.755227, mix_dice: 1.365022, mix_ce: 0.145432
iteration 78: loss: 0.755227, mix_dice: 1.365022, mix_ce: 0.145432
iteration 78: loss: 0.755227, mix_dice: 1.365022, mix_ce: 0.145432
iteration 78: loss: 0.755227, mix_dice: 1.365022, mix_ce: 0.145432
iteration 79: loss: 0.778309, mix_dice: 1.377711, mix_ce: 0.178908
iteration 79: loss: 0.778309, mix_dice: 1.377711, mix_ce: 0.178908
iteration 79: loss: 0.778309, mix_dice: 1.377711, mix_ce: 0.178908
iteration 79: loss: 0.778309, mix_dice: 1.377711, mix_ce: 0.178908
iteration 79: loss: 0.778309, mix_dice: 1.377711, mix_ce: 0.178908
iteration 79: loss: 0.778309, mix_dice: 1.377711, mix_ce: 0.178908
iteration 79: loss: 0.778309, mix_dice: 1.377711, mix_ce: 0.178908
iteration 80: loss: 0.810393, mix_dice: 1.426426, mix_ce: 0.19

 22%|███████▎                          | 8/37 [00:36<02:10,  4.51s/it]

iteration 89: loss: 0.728059, mix_dice: 1.273998, mix_ce: 0.182119
iteration 89: loss: 0.728059, mix_dice: 1.273998, mix_ce: 0.182119
iteration 89: loss: 0.728059, mix_dice: 1.273998, mix_ce: 0.182119
iteration 89: loss: 0.728059, mix_dice: 1.273998, mix_ce: 0.182119
iteration 89: loss: 0.728059, mix_dice: 1.273998, mix_ce: 0.182119
iteration 89: loss: 0.728059, mix_dice: 1.273998, mix_ce: 0.182119
iteration 89: loss: 0.728059, mix_dice: 1.273998, mix_ce: 0.182119
iteration 90: loss: 0.792701, mix_dice: 1.374296, mix_ce: 0.211106
iteration 90: loss: 0.792701, mix_dice: 1.374296, mix_ce: 0.211106
iteration 90: loss: 0.792701, mix_dice: 1.374296, mix_ce: 0.211106
iteration 90: loss: 0.792701, mix_dice: 1.374296, mix_ce: 0.211106
iteration 90: loss: 0.792701, mix_dice: 1.374296, mix_ce: 0.211106
iteration 90: loss: 0.792701, mix_dice: 1.374296, mix_ce: 0.211106
iteration 90: loss: 0.792701, mix_dice: 1.374296, mix_ce: 0.211106
iteration 91: loss: 0.757396, mix_dice: 1.352878, mix_ce: 0.16

 24%|████████▎                         | 9/37 [00:40<02:05,  4.47s/it]

iteration 100: loss: 0.781123, mix_dice: 1.391029, mix_ce: 0.171216
iteration 100: loss: 0.781123, mix_dice: 1.391029, mix_ce: 0.171216
iteration 100: loss: 0.781123, mix_dice: 1.391029, mix_ce: 0.171216
iteration 100: loss: 0.781123, mix_dice: 1.391029, mix_ce: 0.171216
iteration 100: loss: 0.781123, mix_dice: 1.391029, mix_ce: 0.171216
iteration 100: loss: 0.781123, mix_dice: 1.391029, mix_ce: 0.171216
iteration 100: loss: 0.781123, mix_dice: 1.391029, mix_ce: 0.171216
iteration 101: loss: 0.804364, mix_dice: 1.422450, mix_ce: 0.186277
iteration 101: loss: 0.804364, mix_dice: 1.422450, mix_ce: 0.186277
iteration 101: loss: 0.804364, mix_dice: 1.422450, mix_ce: 0.186277
iteration 101: loss: 0.804364, mix_dice: 1.422450, mix_ce: 0.186277
iteration 101: loss: 0.804364, mix_dice: 1.422450, mix_ce: 0.186277
iteration 101: loss: 0.804364, mix_dice: 1.422450, mix_ce: 0.186277
iteration 101: loss: 0.804364, mix_dice: 1.422450, mix_ce: 0.186277
iteration 102: loss: 0.803247, mix_dice: 1.43274

 27%|████████▉                        | 10/37 [00:45<02:02,  4.54s/it]

iteration 111: loss: 0.793971, mix_dice: 1.410794, mix_ce: 0.177149
iteration 111: loss: 0.793971, mix_dice: 1.410794, mix_ce: 0.177149
iteration 111: loss: 0.793971, mix_dice: 1.410794, mix_ce: 0.177149
iteration 111: loss: 0.793971, mix_dice: 1.410794, mix_ce: 0.177149
iteration 111: loss: 0.793971, mix_dice: 1.410794, mix_ce: 0.177149
iteration 111: loss: 0.793971, mix_dice: 1.410794, mix_ce: 0.177149
iteration 111: loss: 0.793971, mix_dice: 1.410794, mix_ce: 0.177149
iteration 112: loss: 0.763720, mix_dice: 1.360216, mix_ce: 0.167223
iteration 112: loss: 0.763720, mix_dice: 1.360216, mix_ce: 0.167223
iteration 112: loss: 0.763720, mix_dice: 1.360216, mix_ce: 0.167223
iteration 112: loss: 0.763720, mix_dice: 1.360216, mix_ce: 0.167223
iteration 112: loss: 0.763720, mix_dice: 1.360216, mix_ce: 0.167223
iteration 112: loss: 0.763720, mix_dice: 1.360216, mix_ce: 0.167223
iteration 112: loss: 0.763720, mix_dice: 1.360216, mix_ce: 0.167223
iteration 113: loss: 0.891090, mix_dice: 1.55377

 30%|█████████▊                       | 11/37 [00:49<01:57,  4.53s/it]

iteration 122: loss: 0.875557, mix_dice: 1.504773, mix_ce: 0.246340
iteration 122: loss: 0.875557, mix_dice: 1.504773, mix_ce: 0.246340
iteration 122: loss: 0.875557, mix_dice: 1.504773, mix_ce: 0.246340
iteration 122: loss: 0.875557, mix_dice: 1.504773, mix_ce: 0.246340
iteration 122: loss: 0.875557, mix_dice: 1.504773, mix_ce: 0.246340
iteration 122: loss: 0.875557, mix_dice: 1.504773, mix_ce: 0.246340
iteration 122: loss: 0.875557, mix_dice: 1.504773, mix_ce: 0.246340
iteration 123: loss: 0.753727, mix_dice: 1.337033, mix_ce: 0.170422
iteration 123: loss: 0.753727, mix_dice: 1.337033, mix_ce: 0.170422
iteration 123: loss: 0.753727, mix_dice: 1.337033, mix_ce: 0.170422
iteration 123: loss: 0.753727, mix_dice: 1.337033, mix_ce: 0.170422
iteration 123: loss: 0.753727, mix_dice: 1.337033, mix_ce: 0.170422
iteration 123: loss: 0.753727, mix_dice: 1.337033, mix_ce: 0.170422
iteration 123: loss: 0.753727, mix_dice: 1.337033, mix_ce: 0.170422
iteration 124: loss: 0.636247, mix_dice: 1.02013

 32%|██████████▋                      | 12/37 [00:54<01:52,  4.50s/it]

iteration 133: loss: 0.812735, mix_dice: 1.413704, mix_ce: 0.211766
iteration 133: loss: 0.812735, mix_dice: 1.413704, mix_ce: 0.211766
iteration 133: loss: 0.812735, mix_dice: 1.413704, mix_ce: 0.211766
iteration 133: loss: 0.812735, mix_dice: 1.413704, mix_ce: 0.211766
iteration 133: loss: 0.812735, mix_dice: 1.413704, mix_ce: 0.211766
iteration 133: loss: 0.812735, mix_dice: 1.413704, mix_ce: 0.211766
iteration 133: loss: 0.812735, mix_dice: 1.413704, mix_ce: 0.211766
iteration 134: loss: 0.749159, mix_dice: 1.255905, mix_ce: 0.242413
iteration 134: loss: 0.749159, mix_dice: 1.255905, mix_ce: 0.242413
iteration 134: loss: 0.749159, mix_dice: 1.255905, mix_ce: 0.242413
iteration 134: loss: 0.749159, mix_dice: 1.255905, mix_ce: 0.242413
iteration 134: loss: 0.749159, mix_dice: 1.255905, mix_ce: 0.242413
iteration 134: loss: 0.749159, mix_dice: 1.255905, mix_ce: 0.242413
iteration 134: loss: 0.749159, mix_dice: 1.255905, mix_ce: 0.242413
iteration 135: loss: 0.845545, mix_dice: 1.48223

 35%|███████████▌                     | 13/37 [00:58<01:48,  4.52s/it]

iteration 144: loss: 0.771527, mix_dice: 1.385912, mix_ce: 0.157143
iteration 144: loss: 0.771527, mix_dice: 1.385912, mix_ce: 0.157143
iteration 144: loss: 0.771527, mix_dice: 1.385912, mix_ce: 0.157143
iteration 144: loss: 0.771527, mix_dice: 1.385912, mix_ce: 0.157143
iteration 144: loss: 0.771527, mix_dice: 1.385912, mix_ce: 0.157143
iteration 144: loss: 0.771527, mix_dice: 1.385912, mix_ce: 0.157143
iteration 144: loss: 0.771527, mix_dice: 1.385912, mix_ce: 0.157143
iteration 145: loss: 0.697551, mix_dice: 1.266752, mix_ce: 0.128351
iteration 145: loss: 0.697551, mix_dice: 1.266752, mix_ce: 0.128351
iteration 145: loss: 0.697551, mix_dice: 1.266752, mix_ce: 0.128351
iteration 145: loss: 0.697551, mix_dice: 1.266752, mix_ce: 0.128351
iteration 145: loss: 0.697551, mix_dice: 1.266752, mix_ce: 0.128351
iteration 145: loss: 0.697551, mix_dice: 1.266752, mix_ce: 0.128351
iteration 145: loss: 0.697551, mix_dice: 1.266752, mix_ce: 0.128351
iteration 146: loss: 0.737607, mix_dice: 1.34221

 38%|████████████▍                    | 14/37 [01:03<01:43,  4.50s/it]

iteration 155: loss: 0.767837, mix_dice: 1.361139, mix_ce: 0.174535
iteration 155: loss: 0.767837, mix_dice: 1.361139, mix_ce: 0.174535
iteration 155: loss: 0.767837, mix_dice: 1.361139, mix_ce: 0.174535
iteration 155: loss: 0.767837, mix_dice: 1.361139, mix_ce: 0.174535
iteration 155: loss: 0.767837, mix_dice: 1.361139, mix_ce: 0.174535
iteration 155: loss: 0.767837, mix_dice: 1.361139, mix_ce: 0.174535
iteration 155: loss: 0.767837, mix_dice: 1.361139, mix_ce: 0.174535
iteration 156: loss: 0.730021, mix_dice: 1.283902, mix_ce: 0.176140
iteration 156: loss: 0.730021, mix_dice: 1.283902, mix_ce: 0.176140
iteration 156: loss: 0.730021, mix_dice: 1.283902, mix_ce: 0.176140
iteration 156: loss: 0.730021, mix_dice: 1.283902, mix_ce: 0.176140
iteration 156: loss: 0.730021, mix_dice: 1.283902, mix_ce: 0.176140
iteration 156: loss: 0.730021, mix_dice: 1.283902, mix_ce: 0.176140
iteration 156: loss: 0.730021, mix_dice: 1.283902, mix_ce: 0.176140
iteration 157: loss: 0.776518, mix_dice: 1.39220

 41%|█████████████▍                   | 15/37 [01:07<01:39,  4.51s/it]

iteration 166: loss: 0.726249, mix_dice: 1.323571, mix_ce: 0.128926
iteration 166: loss: 0.726249, mix_dice: 1.323571, mix_ce: 0.128926
iteration 166: loss: 0.726249, mix_dice: 1.323571, mix_ce: 0.128926
iteration 166: loss: 0.726249, mix_dice: 1.323571, mix_ce: 0.128926
iteration 166: loss: 0.726249, mix_dice: 1.323571, mix_ce: 0.128926
iteration 166: loss: 0.726249, mix_dice: 1.323571, mix_ce: 0.128926
iteration 166: loss: 0.726249, mix_dice: 1.323571, mix_ce: 0.128926
iteration 167: loss: 0.791367, mix_dice: 1.415154, mix_ce: 0.167581
iteration 167: loss: 0.791367, mix_dice: 1.415154, mix_ce: 0.167581
iteration 167: loss: 0.791367, mix_dice: 1.415154, mix_ce: 0.167581
iteration 167: loss: 0.791367, mix_dice: 1.415154, mix_ce: 0.167581
iteration 167: loss: 0.791367, mix_dice: 1.415154, mix_ce: 0.167581
iteration 167: loss: 0.791367, mix_dice: 1.415154, mix_ce: 0.167581
iteration 167: loss: 0.791367, mix_dice: 1.415154, mix_ce: 0.167581
iteration 168: loss: 0.734259, mix_dice: 1.33142

 43%|██████████████▎                  | 16/37 [01:12<01:34,  4.51s/it]

iteration 177: loss: 0.730963, mix_dice: 1.284147, mix_ce: 0.177778
iteration 177: loss: 0.730963, mix_dice: 1.284147, mix_ce: 0.177778
iteration 177: loss: 0.730963, mix_dice: 1.284147, mix_ce: 0.177778
iteration 177: loss: 0.730963, mix_dice: 1.284147, mix_ce: 0.177778
iteration 177: loss: 0.730963, mix_dice: 1.284147, mix_ce: 0.177778
iteration 177: loss: 0.730963, mix_dice: 1.284147, mix_ce: 0.177778
iteration 177: loss: 0.730963, mix_dice: 1.284147, mix_ce: 0.177778
iteration 178: loss: 0.685099, mix_dice: 1.229893, mix_ce: 0.140306
iteration 178: loss: 0.685099, mix_dice: 1.229893, mix_ce: 0.140306
iteration 178: loss: 0.685099, mix_dice: 1.229893, mix_ce: 0.140306
iteration 178: loss: 0.685099, mix_dice: 1.229893, mix_ce: 0.140306
iteration 178: loss: 0.685099, mix_dice: 1.229893, mix_ce: 0.140306
iteration 178: loss: 0.685099, mix_dice: 1.229893, mix_ce: 0.140306
iteration 178: loss: 0.685099, mix_dice: 1.229893, mix_ce: 0.140306
iteration 179: loss: 0.716653, mix_dice: 1.29998

 46%|███████████████▏                 | 17/37 [01:16<01:30,  4.55s/it]

iteration 188: loss: 0.786434, mix_dice: 1.422298, mix_ce: 0.150570
iteration 188: loss: 0.786434, mix_dice: 1.422298, mix_ce: 0.150570
iteration 188: loss: 0.786434, mix_dice: 1.422298, mix_ce: 0.150570
iteration 188: loss: 0.786434, mix_dice: 1.422298, mix_ce: 0.150570
iteration 188: loss: 0.786434, mix_dice: 1.422298, mix_ce: 0.150570
iteration 188: loss: 0.786434, mix_dice: 1.422298, mix_ce: 0.150570
iteration 188: loss: 0.786434, mix_dice: 1.422298, mix_ce: 0.150570
iteration 189: loss: 0.722948, mix_dice: 1.326747, mix_ce: 0.119149
iteration 189: loss: 0.722948, mix_dice: 1.326747, mix_ce: 0.119149
iteration 189: loss: 0.722948, mix_dice: 1.326747, mix_ce: 0.119149
iteration 189: loss: 0.722948, mix_dice: 1.326747, mix_ce: 0.119149
iteration 189: loss: 0.722948, mix_dice: 1.326747, mix_ce: 0.119149
iteration 189: loss: 0.722948, mix_dice: 1.326747, mix_ce: 0.119149
iteration 189: loss: 0.722948, mix_dice: 1.326747, mix_ce: 0.119149
iteration 190: loss: 0.754522, mix_dice: 1.35504

 49%|████████████████                 | 18/37 [01:21<01:25,  4.52s/it]

iteration 199: loss: 0.763280, mix_dice: 1.358388, mix_ce: 0.168171
iteration 199: loss: 0.763280, mix_dice: 1.358388, mix_ce: 0.168171
iteration 199: loss: 0.763280, mix_dice: 1.358388, mix_ce: 0.168171
iteration 199: loss: 0.763280, mix_dice: 1.358388, mix_ce: 0.168171
iteration 199: loss: 0.763280, mix_dice: 1.358388, mix_ce: 0.168171
iteration 199: loss: 0.763280, mix_dice: 1.358388, mix_ce: 0.168171
iteration 199: loss: 0.763280, mix_dice: 1.358388, mix_ce: 0.168171
iteration 200: loss: 0.650187, mix_dice: 1.131390, mix_ce: 0.168984
iteration 200: loss: 0.650187, mix_dice: 1.131390, mix_ce: 0.168984
iteration 200: loss: 0.650187, mix_dice: 1.131390, mix_ce: 0.168984
iteration 200: loss: 0.650187, mix_dice: 1.131390, mix_ce: 0.168984
iteration 200: loss: 0.650187, mix_dice: 1.131390, mix_ce: 0.168984
iteration 200: loss: 0.650187, mix_dice: 1.131390, mix_ce: 0.168984
iteration 200: loss: 0.650187, mix_dice: 1.131390, mix_ce: 0.168984
Metric list: [[ 0.54705725 17.05745242]
 [ 0.649

 51%|████████████████▉                | 19/37 [01:33<02:04,  6.94s/it]

iteration 210: loss: 0.652453, mix_dice: 1.132675, mix_ce: 0.172231
iteration 210: loss: 0.652453, mix_dice: 1.132675, mix_ce: 0.172231
iteration 210: loss: 0.652453, mix_dice: 1.132675, mix_ce: 0.172231
iteration 210: loss: 0.652453, mix_dice: 1.132675, mix_ce: 0.172231
iteration 210: loss: 0.652453, mix_dice: 1.132675, mix_ce: 0.172231
iteration 210: loss: 0.652453, mix_dice: 1.132675, mix_ce: 0.172231
iteration 210: loss: 0.652453, mix_dice: 1.132675, mix_ce: 0.172231
iteration 211: loss: 0.676136, mix_dice: 1.207670, mix_ce: 0.144601
iteration 211: loss: 0.676136, mix_dice: 1.207670, mix_ce: 0.144601
iteration 211: loss: 0.676136, mix_dice: 1.207670, mix_ce: 0.144601
iteration 211: loss: 0.676136, mix_dice: 1.207670, mix_ce: 0.144601
iteration 211: loss: 0.676136, mix_dice: 1.207670, mix_ce: 0.144601
iteration 211: loss: 0.676136, mix_dice: 1.207670, mix_ce: 0.144601
iteration 211: loss: 0.676136, mix_dice: 1.207670, mix_ce: 0.144601
iteration 212: loss: 0.824720, mix_dice: 1.45817

 54%|█████████████████▊               | 20/37 [01:38<01:45,  6.23s/it]

iteration 221: loss: 0.639116, mix_dice: 1.147562, mix_ce: 0.130671
iteration 221: loss: 0.639116, mix_dice: 1.147562, mix_ce: 0.130671
iteration 221: loss: 0.639116, mix_dice: 1.147562, mix_ce: 0.130671
iteration 221: loss: 0.639116, mix_dice: 1.147562, mix_ce: 0.130671
iteration 221: loss: 0.639116, mix_dice: 1.147562, mix_ce: 0.130671
iteration 221: loss: 0.639116, mix_dice: 1.147562, mix_ce: 0.130671
iteration 221: loss: 0.639116, mix_dice: 1.147562, mix_ce: 0.130671
iteration 222: loss: 0.691272, mix_dice: 1.263793, mix_ce: 0.118752
iteration 222: loss: 0.691272, mix_dice: 1.263793, mix_ce: 0.118752
iteration 222: loss: 0.691272, mix_dice: 1.263793, mix_ce: 0.118752
iteration 222: loss: 0.691272, mix_dice: 1.263793, mix_ce: 0.118752
iteration 222: loss: 0.691272, mix_dice: 1.263793, mix_ce: 0.118752
iteration 222: loss: 0.691272, mix_dice: 1.263793, mix_ce: 0.118752
iteration 222: loss: 0.691272, mix_dice: 1.263793, mix_ce: 0.118752
iteration 223: loss: 0.706214, mix_dice: 1.26694

 57%|██████████████████▋              | 21/37 [01:42<01:31,  5.71s/it]

iteration 232: loss: 0.684019, mix_dice: 1.225533, mix_ce: 0.142504
iteration 232: loss: 0.684019, mix_dice: 1.225533, mix_ce: 0.142504
iteration 232: loss: 0.684019, mix_dice: 1.225533, mix_ce: 0.142504
iteration 232: loss: 0.684019, mix_dice: 1.225533, mix_ce: 0.142504
iteration 232: loss: 0.684019, mix_dice: 1.225533, mix_ce: 0.142504
iteration 232: loss: 0.684019, mix_dice: 1.225533, mix_ce: 0.142504
iteration 232: loss: 0.684019, mix_dice: 1.225533, mix_ce: 0.142504
iteration 233: loss: 0.739461, mix_dice: 1.294288, mix_ce: 0.184635
iteration 233: loss: 0.739461, mix_dice: 1.294288, mix_ce: 0.184635
iteration 233: loss: 0.739461, mix_dice: 1.294288, mix_ce: 0.184635
iteration 233: loss: 0.739461, mix_dice: 1.294288, mix_ce: 0.184635
iteration 233: loss: 0.739461, mix_dice: 1.294288, mix_ce: 0.184635
iteration 233: loss: 0.739461, mix_dice: 1.294288, mix_ce: 0.184635
iteration 233: loss: 0.739461, mix_dice: 1.294288, mix_ce: 0.184635
iteration 234: loss: 0.796183, mix_dice: 1.40645

 59%|███████████████████▌             | 22/37 [01:47<01:20,  5.38s/it]

iteration 243: loss: 0.634784, mix_dice: 1.153841, mix_ce: 0.115727
iteration 243: loss: 0.634784, mix_dice: 1.153841, mix_ce: 0.115727
iteration 243: loss: 0.634784, mix_dice: 1.153841, mix_ce: 0.115727
iteration 243: loss: 0.634784, mix_dice: 1.153841, mix_ce: 0.115727
iteration 243: loss: 0.634784, mix_dice: 1.153841, mix_ce: 0.115727
iteration 243: loss: 0.634784, mix_dice: 1.153841, mix_ce: 0.115727
iteration 243: loss: 0.634784, mix_dice: 1.153841, mix_ce: 0.115727
iteration 244: loss: 0.773567, mix_dice: 1.398362, mix_ce: 0.148772
iteration 244: loss: 0.773567, mix_dice: 1.398362, mix_ce: 0.148772
iteration 244: loss: 0.773567, mix_dice: 1.398362, mix_ce: 0.148772
iteration 244: loss: 0.773567, mix_dice: 1.398362, mix_ce: 0.148772
iteration 244: loss: 0.773567, mix_dice: 1.398362, mix_ce: 0.148772
iteration 244: loss: 0.773567, mix_dice: 1.398362, mix_ce: 0.148772
iteration 244: loss: 0.773567, mix_dice: 1.398362, mix_ce: 0.148772
iteration 245: loss: 0.714217, mix_dice: 1.31323

 62%|████████████████████▌            | 23/37 [01:52<01:11,  5.09s/it]

iteration 254: loss: 0.674467, mix_dice: 1.182250, mix_ce: 0.166684
iteration 254: loss: 0.674467, mix_dice: 1.182250, mix_ce: 0.166684
iteration 254: loss: 0.674467, mix_dice: 1.182250, mix_ce: 0.166684
iteration 254: loss: 0.674467, mix_dice: 1.182250, mix_ce: 0.166684
iteration 254: loss: 0.674467, mix_dice: 1.182250, mix_ce: 0.166684
iteration 254: loss: 0.674467, mix_dice: 1.182250, mix_ce: 0.166684
iteration 254: loss: 0.674467, mix_dice: 1.182250, mix_ce: 0.166684
iteration 255: loss: 0.677700, mix_dice: 1.188021, mix_ce: 0.167379
iteration 255: loss: 0.677700, mix_dice: 1.188021, mix_ce: 0.167379
iteration 255: loss: 0.677700, mix_dice: 1.188021, mix_ce: 0.167379
iteration 255: loss: 0.677700, mix_dice: 1.188021, mix_ce: 0.167379
iteration 255: loss: 0.677700, mix_dice: 1.188021, mix_ce: 0.167379
iteration 255: loss: 0.677700, mix_dice: 1.188021, mix_ce: 0.167379
iteration 255: loss: 0.677700, mix_dice: 1.188021, mix_ce: 0.167379
iteration 256: loss: 0.532699, mix_dice: 0.93239

 65%|█████████████████████▍           | 24/37 [01:56<01:03,  4.91s/it]

iteration 265: loss: 0.719557, mix_dice: 1.304100, mix_ce: 0.135015
iteration 265: loss: 0.719557, mix_dice: 1.304100, mix_ce: 0.135015
iteration 265: loss: 0.719557, mix_dice: 1.304100, mix_ce: 0.135015
iteration 265: loss: 0.719557, mix_dice: 1.304100, mix_ce: 0.135015
iteration 265: loss: 0.719557, mix_dice: 1.304100, mix_ce: 0.135015
iteration 265: loss: 0.719557, mix_dice: 1.304100, mix_ce: 0.135015
iteration 265: loss: 0.719557, mix_dice: 1.304100, mix_ce: 0.135015
iteration 266: loss: 0.808382, mix_dice: 1.462335, mix_ce: 0.154429
iteration 266: loss: 0.808382, mix_dice: 1.462335, mix_ce: 0.154429
iteration 266: loss: 0.808382, mix_dice: 1.462335, mix_ce: 0.154429
iteration 266: loss: 0.808382, mix_dice: 1.462335, mix_ce: 0.154429
iteration 266: loss: 0.808382, mix_dice: 1.462335, mix_ce: 0.154429
iteration 266: loss: 0.808382, mix_dice: 1.462335, mix_ce: 0.154429
iteration 266: loss: 0.808382, mix_dice: 1.462335, mix_ce: 0.154429
iteration 267: loss: 0.582394, mix_dice: 1.04790

 68%|██████████████████████▎          | 25/37 [02:00<00:57,  4.78s/it]

iteration 276: loss: 0.594413, mix_dice: 1.028867, mix_ce: 0.159958
iteration 276: loss: 0.594413, mix_dice: 1.028867, mix_ce: 0.159958
iteration 276: loss: 0.594413, mix_dice: 1.028867, mix_ce: 0.159958
iteration 276: loss: 0.594413, mix_dice: 1.028867, mix_ce: 0.159958
iteration 276: loss: 0.594413, mix_dice: 1.028867, mix_ce: 0.159958
iteration 276: loss: 0.594413, mix_dice: 1.028867, mix_ce: 0.159958
iteration 276: loss: 0.594413, mix_dice: 1.028867, mix_ce: 0.159958
iteration 277: loss: 0.768735, mix_dice: 1.357621, mix_ce: 0.179850
iteration 277: loss: 0.768735, mix_dice: 1.357621, mix_ce: 0.179850
iteration 277: loss: 0.768735, mix_dice: 1.357621, mix_ce: 0.179850
iteration 277: loss: 0.768735, mix_dice: 1.357621, mix_ce: 0.179850
iteration 277: loss: 0.768735, mix_dice: 1.357621, mix_ce: 0.179850
iteration 277: loss: 0.768735, mix_dice: 1.357621, mix_ce: 0.179850
iteration 277: loss: 0.768735, mix_dice: 1.357621, mix_ce: 0.179850
iteration 278: loss: 0.568999, mix_dice: 1.00377

 70%|███████████████████████▏         | 26/37 [02:05<00:51,  4.71s/it]

iteration 287: loss: 0.624370, mix_dice: 1.063510, mix_ce: 0.185230
iteration 287: loss: 0.624370, mix_dice: 1.063510, mix_ce: 0.185230
iteration 287: loss: 0.624370, mix_dice: 1.063510, mix_ce: 0.185230
iteration 287: loss: 0.624370, mix_dice: 1.063510, mix_ce: 0.185230
iteration 287: loss: 0.624370, mix_dice: 1.063510, mix_ce: 0.185230
iteration 287: loss: 0.624370, mix_dice: 1.063510, mix_ce: 0.185230
iteration 287: loss: 0.624370, mix_dice: 1.063510, mix_ce: 0.185230
iteration 288: loss: 0.737741, mix_dice: 1.347007, mix_ce: 0.128476
iteration 288: loss: 0.737741, mix_dice: 1.347007, mix_ce: 0.128476
iteration 288: loss: 0.737741, mix_dice: 1.347007, mix_ce: 0.128476
iteration 288: loss: 0.737741, mix_dice: 1.347007, mix_ce: 0.128476
iteration 288: loss: 0.737741, mix_dice: 1.347007, mix_ce: 0.128476
iteration 288: loss: 0.737741, mix_dice: 1.347007, mix_ce: 0.128476
iteration 288: loss: 0.737741, mix_dice: 1.347007, mix_ce: 0.128476
iteration 289: loss: 0.706570, mix_dice: 1.29175

 73%|████████████████████████         | 27/37 [02:10<00:46,  4.66s/it]

iteration 298: loss: 0.713367, mix_dice: 1.284214, mix_ce: 0.142521
iteration 298: loss: 0.713367, mix_dice: 1.284214, mix_ce: 0.142521
iteration 298: loss: 0.713367, mix_dice: 1.284214, mix_ce: 0.142521
iteration 298: loss: 0.713367, mix_dice: 1.284214, mix_ce: 0.142521
iteration 298: loss: 0.713367, mix_dice: 1.284214, mix_ce: 0.142521
iteration 298: loss: 0.713367, mix_dice: 1.284214, mix_ce: 0.142521
iteration 298: loss: 0.713367, mix_dice: 1.284214, mix_ce: 0.142521
iteration 299: loss: 0.720479, mix_dice: 1.337677, mix_ce: 0.103282
iteration 299: loss: 0.720479, mix_dice: 1.337677, mix_ce: 0.103282
iteration 299: loss: 0.720479, mix_dice: 1.337677, mix_ce: 0.103282
iteration 299: loss: 0.720479, mix_dice: 1.337677, mix_ce: 0.103282
iteration 299: loss: 0.720479, mix_dice: 1.337677, mix_ce: 0.103282
iteration 299: loss: 0.720479, mix_dice: 1.337677, mix_ce: 0.103282
iteration 299: loss: 0.720479, mix_dice: 1.337677, mix_ce: 0.103282
iteration 300: loss: 0.728929, mix_dice: 1.30345

 76%|████████████████████████▉        | 28/37 [02:14<00:41,  4.64s/it]

iteration 309: loss: 0.752826, mix_dice: 1.375992, mix_ce: 0.129661
iteration 309: loss: 0.752826, mix_dice: 1.375992, mix_ce: 0.129661
iteration 309: loss: 0.752826, mix_dice: 1.375992, mix_ce: 0.129661
iteration 309: loss: 0.752826, mix_dice: 1.375992, mix_ce: 0.129661
iteration 309: loss: 0.752826, mix_dice: 1.375992, mix_ce: 0.129661
iteration 309: loss: 0.752826, mix_dice: 1.375992, mix_ce: 0.129661
iteration 309: loss: 0.752826, mix_dice: 1.375992, mix_ce: 0.129661
iteration 310: loss: 0.711660, mix_dice: 1.240242, mix_ce: 0.183078
iteration 310: loss: 0.711660, mix_dice: 1.240242, mix_ce: 0.183078
iteration 310: loss: 0.711660, mix_dice: 1.240242, mix_ce: 0.183078
iteration 310: loss: 0.711660, mix_dice: 1.240242, mix_ce: 0.183078
iteration 310: loss: 0.711660, mix_dice: 1.240242, mix_ce: 0.183078
iteration 310: loss: 0.711660, mix_dice: 1.240242, mix_ce: 0.183078
iteration 310: loss: 0.711660, mix_dice: 1.240242, mix_ce: 0.183078
iteration 311: loss: 0.735326, mix_dice: 1.34862

 78%|█████████████████████████▊       | 29/37 [02:19<00:36,  4.61s/it]

iteration 320: loss: 0.456274, mix_dice: 0.749560, mix_ce: 0.162987
iteration 320: loss: 0.456274, mix_dice: 0.749560, mix_ce: 0.162987
iteration 320: loss: 0.456274, mix_dice: 0.749560, mix_ce: 0.162987
iteration 320: loss: 0.456274, mix_dice: 0.749560, mix_ce: 0.162987
iteration 320: loss: 0.456274, mix_dice: 0.749560, mix_ce: 0.162987
iteration 320: loss: 0.456274, mix_dice: 0.749560, mix_ce: 0.162987
iteration 320: loss: 0.456274, mix_dice: 0.749560, mix_ce: 0.162987
iteration 321: loss: 0.702584, mix_dice: 1.283707, mix_ce: 0.121462
iteration 321: loss: 0.702584, mix_dice: 1.283707, mix_ce: 0.121462
iteration 321: loss: 0.702584, mix_dice: 1.283707, mix_ce: 0.121462
iteration 321: loss: 0.702584, mix_dice: 1.283707, mix_ce: 0.121462
iteration 321: loss: 0.702584, mix_dice: 1.283707, mix_ce: 0.121462
iteration 321: loss: 0.702584, mix_dice: 1.283707, mix_ce: 0.121462
iteration 321: loss: 0.702584, mix_dice: 1.283707, mix_ce: 0.121462
iteration 322: loss: 0.676100, mix_dice: 1.21348

 81%|██████████████████████████▊      | 30/37 [02:23<00:32,  4.60s/it]

iteration 331: loss: 0.814279, mix_dice: 1.447794, mix_ce: 0.180764
iteration 331: loss: 0.814279, mix_dice: 1.447794, mix_ce: 0.180764
iteration 331: loss: 0.814279, mix_dice: 1.447794, mix_ce: 0.180764
iteration 331: loss: 0.814279, mix_dice: 1.447794, mix_ce: 0.180764
iteration 331: loss: 0.814279, mix_dice: 1.447794, mix_ce: 0.180764
iteration 331: loss: 0.814279, mix_dice: 1.447794, mix_ce: 0.180764
iteration 331: loss: 0.814279, mix_dice: 1.447794, mix_ce: 0.180764
iteration 332: loss: 0.700325, mix_dice: 1.281860, mix_ce: 0.118789
iteration 332: loss: 0.700325, mix_dice: 1.281860, mix_ce: 0.118789
iteration 332: loss: 0.700325, mix_dice: 1.281860, mix_ce: 0.118789
iteration 332: loss: 0.700325, mix_dice: 1.281860, mix_ce: 0.118789
iteration 332: loss: 0.700325, mix_dice: 1.281860, mix_ce: 0.118789
iteration 332: loss: 0.700325, mix_dice: 1.281860, mix_ce: 0.118789
iteration 332: loss: 0.700325, mix_dice: 1.281860, mix_ce: 0.118789
iteration 333: loss: 0.613600, mix_dice: 1.10977

 84%|███████████████████████████▋     | 31/37 [02:28<00:27,  4.57s/it]

iteration 342: loss: 0.658071, mix_dice: 1.166049, mix_ce: 0.150092
iteration 342: loss: 0.658071, mix_dice: 1.166049, mix_ce: 0.150092
iteration 342: loss: 0.658071, mix_dice: 1.166049, mix_ce: 0.150092
iteration 342: loss: 0.658071, mix_dice: 1.166049, mix_ce: 0.150092
iteration 342: loss: 0.658071, mix_dice: 1.166049, mix_ce: 0.150092
iteration 342: loss: 0.658071, mix_dice: 1.166049, mix_ce: 0.150092
iteration 342: loss: 0.658071, mix_dice: 1.166049, mix_ce: 0.150092
iteration 343: loss: 0.696459, mix_dice: 1.275962, mix_ce: 0.116957
iteration 343: loss: 0.696459, mix_dice: 1.275962, mix_ce: 0.116957
iteration 343: loss: 0.696459, mix_dice: 1.275962, mix_ce: 0.116957
iteration 343: loss: 0.696459, mix_dice: 1.275962, mix_ce: 0.116957
iteration 343: loss: 0.696459, mix_dice: 1.275962, mix_ce: 0.116957
iteration 343: loss: 0.696459, mix_dice: 1.275962, mix_ce: 0.116957
iteration 343: loss: 0.696459, mix_dice: 1.275962, mix_ce: 0.116957
iteration 344: loss: 0.608602, mix_dice: 1.07910

 86%|████████████████████████████▌    | 32/37 [02:32<00:22,  4.58s/it]

iteration 353: loss: 0.729518, mix_dice: 1.337813, mix_ce: 0.121223
iteration 353: loss: 0.729518, mix_dice: 1.337813, mix_ce: 0.121223
iteration 353: loss: 0.729518, mix_dice: 1.337813, mix_ce: 0.121223
iteration 353: loss: 0.729518, mix_dice: 1.337813, mix_ce: 0.121223
iteration 353: loss: 0.729518, mix_dice: 1.337813, mix_ce: 0.121223
iteration 353: loss: 0.729518, mix_dice: 1.337813, mix_ce: 0.121223
iteration 353: loss: 0.729518, mix_dice: 1.337813, mix_ce: 0.121223
iteration 354: loss: 0.742878, mix_dice: 1.285244, mix_ce: 0.200511
iteration 354: loss: 0.742878, mix_dice: 1.285244, mix_ce: 0.200511
iteration 354: loss: 0.742878, mix_dice: 1.285244, mix_ce: 0.200511
iteration 354: loss: 0.742878, mix_dice: 1.285244, mix_ce: 0.200511
iteration 354: loss: 0.742878, mix_dice: 1.285244, mix_ce: 0.200511
iteration 354: loss: 0.742878, mix_dice: 1.285244, mix_ce: 0.200511
iteration 354: loss: 0.742878, mix_dice: 1.285244, mix_ce: 0.200511
iteration 355: loss: 0.718300, mix_dice: 1.31857

 89%|█████████████████████████████▍   | 33/37 [02:37<00:18,  4.53s/it]

iteration 364: loss: 0.633207, mix_dice: 1.158733, mix_ce: 0.107682
iteration 364: loss: 0.633207, mix_dice: 1.158733, mix_ce: 0.107682
iteration 364: loss: 0.633207, mix_dice: 1.158733, mix_ce: 0.107682
iteration 364: loss: 0.633207, mix_dice: 1.158733, mix_ce: 0.107682
iteration 364: loss: 0.633207, mix_dice: 1.158733, mix_ce: 0.107682
iteration 364: loss: 0.633207, mix_dice: 1.158733, mix_ce: 0.107682
iteration 364: loss: 0.633207, mix_dice: 1.158733, mix_ce: 0.107682
iteration 365: loss: 0.697137, mix_dice: 1.297130, mix_ce: 0.097144
iteration 365: loss: 0.697137, mix_dice: 1.297130, mix_ce: 0.097144
iteration 365: loss: 0.697137, mix_dice: 1.297130, mix_ce: 0.097144
iteration 365: loss: 0.697137, mix_dice: 1.297130, mix_ce: 0.097144
iteration 365: loss: 0.697137, mix_dice: 1.297130, mix_ce: 0.097144
iteration 365: loss: 0.697137, mix_dice: 1.297130, mix_ce: 0.097144
iteration 365: loss: 0.697137, mix_dice: 1.297130, mix_ce: 0.097144
iteration 366: loss: 0.589987, mix_dice: 1.06186

 92%|██████████████████████████████▎  | 34/37 [02:41<00:13,  4.49s/it]

iteration 375: loss: 0.621362, mix_dice: 1.089766, mix_ce: 0.152957
iteration 375: loss: 0.621362, mix_dice: 1.089766, mix_ce: 0.152957
iteration 375: loss: 0.621362, mix_dice: 1.089766, mix_ce: 0.152957
iteration 375: loss: 0.621362, mix_dice: 1.089766, mix_ce: 0.152957
iteration 375: loss: 0.621362, mix_dice: 1.089766, mix_ce: 0.152957
iteration 375: loss: 0.621362, mix_dice: 1.089766, mix_ce: 0.152957
iteration 375: loss: 0.621362, mix_dice: 1.089766, mix_ce: 0.152957
iteration 376: loss: 0.707512, mix_dice: 1.301372, mix_ce: 0.113652
iteration 376: loss: 0.707512, mix_dice: 1.301372, mix_ce: 0.113652
iteration 376: loss: 0.707512, mix_dice: 1.301372, mix_ce: 0.113652
iteration 376: loss: 0.707512, mix_dice: 1.301372, mix_ce: 0.113652
iteration 376: loss: 0.707512, mix_dice: 1.301372, mix_ce: 0.113652
iteration 376: loss: 0.707512, mix_dice: 1.301372, mix_ce: 0.113652
iteration 376: loss: 0.707512, mix_dice: 1.301372, mix_ce: 0.113652
iteration 377: loss: 0.675753, mix_dice: 1.25756

 95%|███████████████████████████████▏ | 35/37 [02:46<00:09,  4.50s/it]

iteration 386: loss: 0.695354, mix_dice: 1.291930, mix_ce: 0.098778
iteration 386: loss: 0.695354, mix_dice: 1.291930, mix_ce: 0.098778
iteration 386: loss: 0.695354, mix_dice: 1.291930, mix_ce: 0.098778
iteration 386: loss: 0.695354, mix_dice: 1.291930, mix_ce: 0.098778
iteration 386: loss: 0.695354, mix_dice: 1.291930, mix_ce: 0.098778
iteration 386: loss: 0.695354, mix_dice: 1.291930, mix_ce: 0.098778
iteration 386: loss: 0.695354, mix_dice: 1.291930, mix_ce: 0.098778
iteration 387: loss: 0.684405, mix_dice: 1.254637, mix_ce: 0.114173
iteration 387: loss: 0.684405, mix_dice: 1.254637, mix_ce: 0.114173
iteration 387: loss: 0.684405, mix_dice: 1.254637, mix_ce: 0.114173
iteration 387: loss: 0.684405, mix_dice: 1.254637, mix_ce: 0.114173
iteration 387: loss: 0.684405, mix_dice: 1.254637, mix_ce: 0.114173
iteration 387: loss: 0.684405, mix_dice: 1.254637, mix_ce: 0.114173
iteration 387: loss: 0.684405, mix_dice: 1.254637, mix_ce: 0.114173
iteration 388: loss: 0.713843, mix_dice: 1.29886

 97%|████████████████████████████████ | 36/37 [02:50<00:04,  4.50s/it]

iteration 397: loss: 0.695511, mix_dice: 1.283136, mix_ce: 0.107887
iteration 397: loss: 0.695511, mix_dice: 1.283136, mix_ce: 0.107887
iteration 397: loss: 0.695511, mix_dice: 1.283136, mix_ce: 0.107887
iteration 397: loss: 0.695511, mix_dice: 1.283136, mix_ce: 0.107887
iteration 397: loss: 0.695511, mix_dice: 1.283136, mix_ce: 0.107887
iteration 397: loss: 0.695511, mix_dice: 1.283136, mix_ce: 0.107887
iteration 397: loss: 0.695511, mix_dice: 1.283136, mix_ce: 0.107887
iteration 398: loss: 0.645210, mix_dice: 1.101206, mix_ce: 0.189215
iteration 398: loss: 0.645210, mix_dice: 1.101206, mix_ce: 0.189215
iteration 398: loss: 0.645210, mix_dice: 1.101206, mix_ce: 0.189215
iteration 398: loss: 0.645210, mix_dice: 1.101206, mix_ce: 0.189215
iteration 398: loss: 0.645210, mix_dice: 1.101206, mix_ce: 0.189215
iteration 398: loss: 0.645210, mix_dice: 1.101206, mix_ce: 0.189215
iteration 398: loss: 0.645210, mix_dice: 1.101206, mix_ce: 0.189215
iteration 399: loss: 0.699340, mix_dice: 1.26470

 97%|████████████████████████████████ | 36/37 [03:00<00:05,  5.02s/it]


In [166]:
# Params 
class params: 
    def __init__(self): 
        self.root_dir = '/kaggle/input/acdc-dataset/ACDC' 
        self.exp = 'BCP' 
        self.model = 'unet' 
        self.pretrain_iterations = 400 
        
        self.selftrain_iterations = 400
        self.batch_size = 24
        self.deterministic = 1 # What the fucck here
        self.base_lr = 0.01 
        self.patch_size = [256,256] 
        self.seed = 42 
        self.num_classes = 4 
        self.stage_name = 'selftrain'

        # label and unlabel 
        self.labeled_bs = 12
        self.label_num = 7 
        self.u_weight = 0.5 

        # Cost 
        self.gpu = '0' 
        self.consistency = 0.1
        self.consistency_rampup = 200.0 
        self.magnitude = '6.0' 
        self.s_param = 6 

args = params()

### --------------------------------------Test self-train-------------------------------

In [172]:
import SimpleITK as sitk
from medpy import metric

In [173]:
def calculate_metric_percase(pred, gt):
    pred[pred > 0] = 1
    gt[gt > 0] = 1
    dice = metric.binary.dc(pred, gt)
    jc = metric.binary.jc(pred, gt)
    asd = metric.binary.asd(pred, gt)
    hd95 = metric.binary.hd95(pred, gt)
    
    return dice, jc, asd, hd95

In [174]:
def test_single_volume(case, net, test_save_path, FLAGS):
    h5f = h5py.File(FLAGS.root_dir + "/data/{}.h5".format(case), 'r')
    image = h5f['image'][:]
    label = h5f['label'][:]
    
    prediction = np.zeros_like(label)
    for ind in range(image.shape[0]):
        slice = image[ind, :, :]
        x, y = slice.shape[0], slice.shape[1]
        slice = zoom(slice, (256 / x, 256 / y), order=0)
        input = torch.from_numpy(slice).unsqueeze(0).unsqueeze(0).float().cuda()
        net.eval()
        with torch.no_grad():
            out_main = net(input)
            if len(out_main) > 1:
                out_main = out_main[0]
            out = torch.argmax(torch.softmax(out_main, dim=1), dim=1).squeeze(0)
            out = out.cpu().detach().numpy()
            pred = zoom(out, (x / 256, y / 256), order=0)
            prediction[ind] = pred
    if np.sum(prediction == 1) == 0:
        first_metric = 0, 0, 0, 0
    else:
        first_metric = calculate_metric_percase(prediction == 1, label == 1)
        
    if np.sum(prediction == 2) == 0:
        second_metric = 0, 0, 0, 0
    else:
        second_metric = calculate_metric_percase(prediction == 2, label == 2)
        
    if np.sum(prediction == 3) == 0:
        third_metric = 0, 0, 0, 0
    else:
        third_metric = calculate_metric_percase(prediction == 2, label == 2)
        
    img_itk  = sitk.GetImageFromArray(image.astype(np.float32))
    img_itk.SetSpacing((1, 1, 10))
    pred_itk  = sitk.GetImageFromArray(prediction.astype(np.float32))
    pred_itk.SetSpacing((1, 1, 10))
    label_itk  = sitk.GetImageFromArray(label.astype(np.float32))
    label_itk.SetSpacing((1, 1, 10))
    
    return first_metric, second_metric, third_metric

In [175]:
def Inference(FLAGS):
    with open(FLAGS.root_dir + '/test.list', 'r') as f:
        image_list = f.readlines()
        
    image_list = sorted([item.replace('\n', '').split(".")[0] for item in image_list])
    snapshot_path = "./model/BCP/ACDC_{}_{}_labeled/{}".format(FLAGS.exp, FLAGS.label_num, FLAGS.stage_name)
    test_save_path = "./model/BCP/ACDC_{}_{}_labeled/{}_predictions/".format(FLAGS.exp, FLAGS.label_num, FLAGS.model)
    # if not os.path.exists(test_save_path):
    #     shutil.rmtree(test_save_path)
    # os.makedirs(test_save_path)
    net = net_factory(net_type=FLAGS.model, in_chns=1, class_num=FLAGS.num_classes)
    save_model_path = os.path.join(snapshot_path, '{}_best_model.pth'.format(FLAGS.model))
    net.load_state_dict(torch.load(save_model_path))
    
    print("init weight from {}".format(save_model_path))
    net.eval()
    
    first_total = 0.0
    second_total = 0.0 
    third_total = 0.0
    for case in tqdm(image_list):
        first_metric, second_metric, third_metric = test_single_volume(case, net, test_save_path, FLAGS)
        first_total += np.asarray(first_metric)
        second_total += np.asarray(second_metric)
        third_total += np.asarray(third_metric)
    avg_metric = [first_total / len(image_list), second_total / len(image_list), third_total / len(image_list)]
    return avg_metric, test_save_path

In [176]:
FLAGS = args
metric, test_save_path = Inference(FLAGS)
print(metric)
print((metric[0]+metric[1]+metric[2])/3)
with open(test_save_path+'../performance.txt', 'w') as f:
    f.writelines('metric is {} \n'.format(metric))
    f.writelines('average metric is {}\n'.format((metric[0]+metric[1]+metric[2])/3))

  net.load_state_dict(torch.load(save_model_path))


init weight from ./model/BCP/ACDC_BCP_7_labeled/selftrain/unet_best_model.pth


100%|██████████| 40/40 [00:19<00:00,  2.09it/s]

[array([0.67397536, 0.53786481, 0.7783597 , 5.014602  ]), array([0.63967933, 0.48177404, 0.66839407, 3.01806013]), array([0.63967933, 0.48177404, 0.66839407, 3.01806013])]
[0.65111134 0.50047096 0.70504928 3.68357409]



