In [17]:
import numpy as np 
import h5py
import os 
import random 

import torch 
import torch.nn as nn 
import torch.nn.functional as F 
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from skimage.measure import label 

import torch.backends.cudnn as cudnn

In [19]:
# Params 
class params:
    """Hyper-parameter container for the 'BCP' experiment on the LA data (root_path='LA').

    The attribute names ``labeles_bs`` and ``bacth_size`` are misspelled. They are kept
    for backward compatibility, and the correctly spelled ``labeled_bs`` / ``batch_size``
    are provided as read/write aliases that stay in sync with them.
    """

    def __init__(self):
        # Data / experiment identity
        self.root_path = 'LA'
        self.exp = 'BCP'
        self.model = 'Unet'

        # Iteration budgets and batch composition
        self.pre_max_iterations = 200
        self.self_max_iteration = 100
        self.max_samples = 80
        self.labeles_bs = 4      # alias: labeled_bs
        self.bacth_size = 8      # alias: batch_size

        # Optimisation / reproducibility
        self.base_lr = 0.01
        self.deterministic = 1
        self.labelnum = 8
        self.consistency = 1.0
        self.consistency_rampup = 40.0
        self.magnitude = 10.0
        self.seed = 10

        # Setting of BCP
        self.u_weight = 0.5
        self.mask_ratio = 2/3

        # Setting of mixup
        self.u_alpha = 2.0
        self.loss_weight = 0.5

    @property
    def labeled_bs(self):
        """Correctly spelled alias for ``labeles_bs`` (presumably labeled samples per batch)."""
        return self.labeles_bs

    @labeled_bs.setter
    def labeled_bs(self, value):
        self.labeles_bs = value

    @property
    def batch_size(self):
        """Correctly spelled alias for ``bacth_size``."""
        return self.bacth_size

    @batch_size.setter
    def batch_size(self, value):
        self.bacth_size = value


args = params()

#### 1. Dataset (LAHeart) and augmentations

In [20]:
import os
import h5py
from torch.utils.data import Dataset

class LAHeart(Dataset):
    """Dataset of cases listed in ``<base_dir>/<split>.list``.

    Each case is read from ``<base_dir>/2018LA_Seg_Training Set/<case>/mri_norm2.h5``
    (datasets ``image`` and ``label``) and returned as ``{'image': ..., 'label': ...}``,
    optionally passed through ``transform``.

    Args:
        base_dir: dataset root directory.
        split: name of the list file (without ``.list``) to read case ids from.
        transform: optional callable applied to the sample dict.
        num: if given, keep only the first ``num`` cases.

    Raises:
        ValueError: if the list file does not exist.
    """

    def __init__(self, base_dir, split='train', transform=None, num=None):
        self._base_dir = base_dir
        self.split = split
        self.transform = transform

        # Path for train/test list
        list_file = os.path.join(self._base_dir, f"{split}.list")
        if not os.path.isfile(list_file):
            raise ValueError(f"The {split} list file is missing: {list_file}")

        with open(list_file, 'r') as file:
            # Skip blank / whitespace-only lines (e.g. trailing newlines); they would
            # otherwise become '' case ids pointing at a non-existent .h5 file.
            self.sample_list = [line.strip() for line in file if line.strip()]

        if num is not None:
            self.sample_list = self.sample_list[:num]

        print(f"Mode = {self.split}, total samples: {len(self.sample_list)}")

    def __len__(self):
        return len(self.sample_list)

    def __getitem__(self, index):
        case = self.sample_list[index]
        file_path = os.path.join(self._base_dir, f'2018LA_Seg_Training Set/{case}/mri_norm2.h5')

        try:
            with h5py.File(file_path, 'r') as h5f:
                image = h5f['image'][:]
                label = h5f['label'][:]
        except FileNotFoundError as err:
            # Re-raise with the resolved path, keeping the original error as the cause.
            raise FileNotFoundError(f"File not found: {file_path}") from err

        sample = {'image': image, 'label': label}
        if self.transform:
            sample = self.transform(sample)

        return sample


In [21]:
# Smoke-test the dataset: load one case and print the volume shapes.
train_db = LAHeart(
    base_dir=args.root_path,
    split='train'
)

sample = train_db[10]
# Double quotes around the f-string keep the inner ['image'] valid before Python 3.12
# (reusing the same quote inside an f-string requires PEP 701 / Python 3.12+).
print(f"Image.shape = {sample['image'].shape}")
print(f"Label.shape = {sample['label'].shape}")

Mode = train, total samples: 80
Image.shape = (203, 142, 88)
Label.shape = (203, 142, 88)


In [22]:
def random_rot_flip(image, label):
    """Apply the same random rotation and flip to an image and its label.

    The rotation is k * 90 degrees with k drawn uniformly from {0, 1, 2, 3}, in the plane of
    the first two axes (np.rot90 default). It is followed by a flip along axis 0 or 1
    (chosen uniformly).

    Returns:
        (image, label): transformed arrays. They are made C-contiguous because np.rot90 /
        np.flip return negative-stride views, which torch.from_numpy rejects.
    """
    k = np.random.randint(0, 4)      # scalar, not a size-1 array
    axis = np.random.randint(0, 2)

    image = np.ascontiguousarray(np.flip(np.rot90(image, k), axis))
    label = np.ascontiguousarray(np.flip(np.rot90(label, k), axis))

    return image, label


class RandomRotFlip:
    """Transform that applies ``random_rot_flip`` to ``sample['image']`` / ``sample['label']``.

    Returns a new dict containing only 'image' and 'label'; any other keys are dropped.
    """

    def __call__(self, sample):
        image, label = random_rot_flip(sample['image'], sample['label'])
        return {'image': image, 'label': label}

In [23]:
class RandomCrop:
    """Randomly crop a 3-D volume, its label and (optionally) its SDF to ``output_size``.

    If any label dimension is <= the requested size, all arrays are first zero-padded
    symmetrically by ``max((out - size) // 2 + 3, 0)`` per axis, so that the crop fits.

    Args:
        output_size: (w, h, d) crop size; only the first three entries are used.
        with_sdf: if True, ``sample['sdf']`` is padded/cropped the same way and returned.
    """

    def __init__(self, output_size, with_sdf=False):
        self.output_size = output_size
        self.with_sdf = with_sdf

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        sdf = sample['sdf'] if self.with_sdf else None
        ow, oh, od = self.output_size[0], self.output_size[1], self.output_size[2]

        if label.shape[0] <= ow or label.shape[1] <= oh or label.shape[2] <= od:
            pw = max((ow - label.shape[0]) // 2 + 3, 0)
            ph = max((oh - label.shape[1]) // 2 + 3, 0)
            pd = max((od - label.shape[2]) // 2 + 3, 0)
            pad_width = [(pw, pw), (ph, ph), (pd, pd)]
            image = np.pad(image, pad_width, mode='constant', constant_values=0)
            label = np.pad(label, pad_width, mode='constant', constant_values=0)
            if self.with_sdf:
                sdf = np.pad(sdf, pad_width, mode='constant', constant_values=0)

        (w, h, d) = image.shape
        # randint's upper bound is exclusive: +1 makes the last valid offset (size - out)
        # reachable, so the final slice along each axis can appear in a crop.
        w1 = np.random.randint(0, w - ow + 1)
        h1 = np.random.randint(0, h - oh + 1)
        d1 = np.random.randint(0, d - od + 1)
        window = (slice(w1, w1 + ow), slice(h1, h1 + oh), slice(d1, d1 + od))

        cropped = {'image': image[window], 'label': label[window]}
        if self.with_sdf:
            cropped['sdf'] = sdf[window]
        return cropped


#### 2. Loss

In [24]:
def to_one_hot(tensor, nclasses):
    """One-hot encode an integer label tensor along dim 1.

    Args:
        tensor: long tensor of shape (N, 1, *spatial) with values in [0, nclasses).
        nclasses: number of classes.

    Returns:
        Tensor of shape (N, nclasses, *spatial) in the default floating dtype, on the same
        device as ``tensor``.
    """
    assert tensor.max().item() < nclasses
    assert tensor.min().item() >= 0

    size = list(tensor.size())
    assert size[1] == 1
    size[1] = nclasses
    # Allocate directly on the input's device, so CPU, CUDA and other backends all work
    # without a separate copy.
    one_hot = torch.zeros(size, device=tensor.device)
    return one_hot.scatter_(1, tensor, 1)

def get_probability(logits):
    """Convert segmentation logits of shape (N, C, ...) into per-class probabilities.

    - C > 1: softmax over dim 1, with C output channels.
    - C == 1: the single channel is a logit for class 1. The sigmoid gives p and the output
      stacks [1 - p, p], so the result has 2 channels.

    Returns:
        (pred, nclass): the probability tensor and its number of channels.
    """
    nclass = logits.size(1)
    if nclass > 1:
        pred = F.softmax(logits, dim=1)
    else:
        fg = torch.sigmoid(logits)
        pred = torch.cat([1 - fg, fg], dim=1)
        # Previously left unset in this branch -> UnboundLocalError on return.
        nclass = 2
    return pred, nclass


class mask_DiceLoss(nn.Module):
    """Soft Dice loss with an optional per-voxel mask.

    With p = class probabilities, t = one-hot target and m = mask (all flattened over
    the spatial dims):

        loss = 1 - mean_{n, c} (2 * sum(p * t * m) + smooth) / (sum((p + t) * m) + smooth)

    Without a mask, m is treated as all ones.
    """

    def __init__(self, nclass, class_weights=None, smooth=1e-5):
        super(mask_DiceLoss, self).__init__()
        self.smooth = smooth
        # NOTE(review): class_weights is stored as a non-trainable Parameter (so it is part of
        # state_dict), but neither forward path uses it; the loss is an unweighted class mean.
        if class_weights is None:
            self.class_weights = nn.Parameter(torch.ones((1, nclass)).type(torch.float32), requires_grad=False)
        else:
            class_weights = np.array(class_weights)
            assert nclass == class_weights.shape[0]
            self.class_weights = nn.Parameter(torch.tensor(class_weights, dtype=torch.float32), requires_grad=False)

    def _dice_loss(self, pred, target, mask):
        """Shared Dice computation for probabilities ``pred`` (N, C, ...) and label ``target``."""
        N, nclass = pred.size(0), pred.size(1)
        pred = pred.view(N, nclass, -1)
        target_one_hot = to_one_hot(target.view(N, 1, -1).type(torch.long), nclass).type(torch.float32)

        inter = pred * target_one_hot
        union = pred + target_one_hot
        if mask is not None:
            # Mask broadcasts over the class dim.
            mask = mask.view(N, 1, -1)
            inter = inter * mask
            union = union * mask
        inter = inter.sum(2)
        union = union.sum(2)

        dice = (2 * inter + self.smooth) / (union + self.smooth)
        return 1 - dice.mean()

    def prob_forward(self, pred, target, mask=None):
        """Dice loss when ``pred`` already holds class probabilities of shape (N, C, ...)."""
        return self._dice_loss(pred, target, mask)

    def forward(self, logits, target, mask=None):
        """Dice loss from raw logits (N, C, ...); probabilities come from get_probability.

        For single-channel logits, get_probability yields 2 channels, and the one-hot
        target is built with 2 classes accordingly.
        """
        N, nclass = logits.size(0), logits.size(1)
        pred, _ = get_probability(logits.view(N, nclass, -1))
        return self._dice_loss(pred, target, mask)
        


#### 3. UNet3D

In [25]:
from torch.nn import Module, Sequential
from torch.nn import Conv3d, ConvTranspose3d, BatchNorm3d, MaxPool3d, AvgPool3d, AvgPool1d, Dropout3d
from torch.nn import ReLU, Sigmoid
import torch
import pdb

class UNet(Module):
    """3-D UNet.

    The network has four MaxPool3d encoder stages, a bottleneck, and four transposed-conv
    decoder stages with skip connections. A final 1x1x1 conv produces ``out_dimension``
    channels.

    Input is (N, in_dimension, D, H, W); forward returns raw logits (no activation).
    Each decoder upsampler exactly doubles the spatial size, while MaxPool3d(2) floors odd
    sizes. D, H and W should therefore be divisible by 16 for the skip concatenations to
    line up.

    Args:
        in_dimension: number of input channels.
        out_dimension: number of output channels.
        ft_channels: feature widths of the 5 encoder levels (finest to coarsest).
        residual: forwarded to every Conv3D_Block.
    """

    def __init__(self, in_dimension=1, out_dimension=2, ft_channels=(64, 256, 256, 512, 1024), residual='conv'):
        super(UNet, self).__init__()

        # Encoder downsamplers
        self.pool1 = MaxPool3d((2, 2, 2))
        self.pool2 = MaxPool3d((2, 2, 2))
        self.pool3 = MaxPool3d((2, 2, 2))
        self.pool4 = MaxPool3d((2, 2, 2))

        # Encoder convolutions
        self.conv_block1 = Conv3D_Block(in_dimension, ft_channels[0], residual=residual)
        self.conv_block2 = Conv3D_Block(ft_channels[0], ft_channels[1], residual=residual)
        self.conv_block3 = Conv3D_Block(ft_channels[1], ft_channels[2], residual=residual)
        self.conv_block4 = Conv3D_Block(ft_channels[2], ft_channels[3], residual=residual)
        self.conv_block5 = Conv3D_Block(ft_channels[3], ft_channels[4], residual=residual)

        # Decoder convolutions (input = upsampled features concatenated with the skip features)
        self.decoder_conv_block4 = Conv3D_Block(2 * ft_channels[3], ft_channels[3], residual=residual)
        self.decoder_conv_block3 = Conv3D_Block(2 * ft_channels[2], ft_channels[2], residual=residual)
        self.decoder_conv_block2 = Conv3D_Block(2 * ft_channels[1], ft_channels[1], residual=residual)
        self.decoder_conv_block1 = Conv3D_Block(2 * ft_channels[0], ft_channels[0], residual=residual)

        # Decoder upsamplers
        self.deconv_block4 = Deconv3D_Block(ft_channels[4], ft_channels[3])
        self.deconv_block3 = Deconv3D_Block(ft_channels[3], ft_channels[2])
        self.deconv_block2 = Deconv3D_Block(ft_channels[2], ft_channels[1])
        self.deconv_block1 = Deconv3D_Block(ft_channels[1], ft_channels[0])

        # Final 1x1x1 convolution -> segmentation logits
        self.one_conv = Conv3d(ft_channels[0], out_dimension, kernel_size=1, stride=1, padding=0, bias=True)

        # Not used by forward(); kept so existing references to `model.sigmoid` keep working.
        self.sigmoid = Sigmoid()

        # Spatial dropout on decoder levels 3 and 2, registered as a submodule so that
        # model.train()/model.eval() toggle it. A Dropout3d constructed inside forward()
        # is a fresh module in training mode, so it stayed active at evaluation time.
        # (No parameters or buffers, so state_dict keys are unchanged.)
        self.dropout = Dropout3d(p=0.05)

    def forward(self, x):
        # Encoder part
        x1 = self.conv_block1(x)
        x_low1 = self.pool1(x1)

        x2 = self.conv_block2(x_low1)
        x_low2 = self.pool2(x2)

        x3 = self.conv_block3(x_low2)
        x_low3 = self.pool3(x3)

        x4 = self.conv_block4(x_low3)
        x_low4 = self.pool4(x4)

        base = self.conv_block5(x_low4)

        # Decoder part
        d4 = torch.cat([self.deconv_block4(base), x4], dim=1)
        d_high4 = self.decoder_conv_block4(d4)

        d3 = torch.cat([self.deconv_block3(d_high4), x3], dim=1)
        d_high3 = self.decoder_conv_block3(d3)
        d_high3 = self.dropout(d_high3)

        d2 = torch.cat([self.deconv_block2(d_high3), x2], dim=1)
        d_high2 = self.decoder_conv_block2(d2)
        d_high2 = self.dropout(d_high2)

        d1 = torch.cat([self.deconv_block1(d_high2), x1], dim=1)
        d_high1 = self.decoder_conv_block1(d1)

        seg = self.one_conv(d_high1)

        return seg

        
class Conv3D_Block(Module):
    """Two stacked (Conv3d -> BatchNorm3d -> ReLU) stages plus an optional residual projection.

    With the defaults (kernel=3, stride=1, padding=1) the spatial size is preserved.
    When ``residual`` is truthy, the input is projected to ``out_features`` channels by a
    bias-free 1x1x1 conv and added to the output.
    """

    def __init__(self, in_features, out_features, kernel=3, stride=1, padding=1, residual=None):
        super(Conv3D_Block, self).__init__()

        def conv_bn_relu(n_in):
            # One conv stage; Sequential indices 0/1/2 define the state_dict keys.
            return Sequential(
                Conv3d(n_in, out_features, kernel_size=kernel, stride=stride, padding=padding, bias=True),
                BatchNorm3d(out_features),
                ReLU(),
            )

        self.conv1 = conv_bn_relu(in_features)
        self.conv2 = conv_bn_relu(out_features)
        self.residual = residual

        # NOTE(review): the projection is created whenever residual is not None, but forward()
        # only applies it when residual is truthy; residual=False or '' leaves an unused layer.
        if residual is not None:
            self.residual_upsampler = Conv3d(in_features, out_features, kernel_size=1, bias=False)

    def forward(self, x):
        out = self.conv2(self.conv1(x))
        if self.residual:
            out = out + self.residual_upsampler(x)
        return out
        
class Deconv3D_Block(Module):
    """Transposed 3-D convolution followed by ReLU; used as the decoder upsampler.

    Output size per spatial dim:
    (in - 1) * stride - 2 * padding + kernel + output_padding.
    With the defaults (kernel=3, stride=2, padding=1, output_padding=1) this is exactly
    2 * in.

    Args:
        output_padding: extra size added to one side of the output. It was previously
            hard-coded to 1, which PyTorch rejects unless it is smaller than ``stride``
            (so stride=1 was unusable). Pass 0 for stride=1.
    """

    def __init__(self, in_features, out_features, kernel=3, stride=2, padding=1, output_padding=1):
        super(Deconv3D_Block, self).__init__()

        self.deconv = Sequential(
            ConvTranspose3d(in_features, out_features, kernel_size=(kernel, kernel, kernel),
                            stride=(stride, stride, stride), padding=(padding, padding, padding),
                            output_padding=output_padding, bias=True),
            ReLU()
        )

    def forward(self, x):
        return self.deconv(x)

In [26]:
def sigmoid_rampup(current, rampup_length):
    """Sigmoid-shaped ramp-up factor in (0, 1].

    Returns exp(-5 * (1 - t)^2) with t = clip(current, 0, rampup_length) / rampup_length.
    This is exp(-5) for current <= 0 and 1.0 once current >= rampup_length.
    A ``rampup_length`` of 0 disables the ramp and always returns 1.0.
    """
    if rampup_length == 0:
        return 1.0
    t = np.clip(current, 0, rampup_length) / rampup_length
    remaining = 1 - t
    return float(np.exp(-5 * remaining * remaining))
    
# Mean-Teacher component
def get_current_consistency_weight(epoch, args):
    """Consistency-loss weight ``5 * args.consistency`` scaled by a sigmoid ramp-up.

    The ramp runs over ``args.consistency_rampup`` epochs. It must multiply the base weight.
    The previous ``5 * consistency + rampup`` only added a term in [exp(-5), 1], so the
    weight stayed ~5 from epoch 0 and the ramp-up had almost no effect.

    NOTE: a later cell in this notebook redefines ``get_current_consistency_weight(epoch)``
    (reading the global ``args``, without the factor 5). Once that cell runs, this
    definition is shadowed.
    """
    return 5 * args.consistency * sigmoid_rampup(epoch, args.consistency_rampup)

def update_model_ema(model, ema_model, alpha):
    """Update ``ema_model`` in place as an exponential moving average of ``model``.

    Floating-point state_dict entries (parameters and float buffers such as BatchNorm
    running statistics) are updated as ``ema = alpha * ema + (1 - alpha) * model``.
    Non-floating entries (e.g. BatchNorm's integer ``num_batches_tracked``) are copied from
    ``model``. Blending them as floats and casting back would truncate the counter.

    Args:
        model: source network.
        ema_model: network with the same state_dict keys; modified in place.
        alpha: EMA decay factor (weight kept on the previous EMA value).
    """
    model_state = model.state_dict()
    with torch.no_grad():
        # state_dict() tensors are detached views sharing storage with ema_model's
        # parameters/buffers, so in-place updates modify the model directly.
        for key, ema_value in ema_model.state_dict().items():
            current = model_state[key]
            if ema_value.is_floating_point():
                ema_value.mul_(alpha).add_(current, alpha=1 - alpha)
            else:
                ema_value.copy_(current)

In [27]:
def get_cut_mask(out, thres=0.5, nms=0):
    """Binary mask of voxels whose class-1 softmax probability is >= ``thres``.

    Args:
        out: logits of shape (N, C, ...); softmax is taken over dim 1.
        thres: probability threshold for class 1.
        nms: if 1, keep only the largest connected component per sample (LargestCC_pancreas).

    Returns:
        Contiguous int64 mask of shape (N, ...), or LargestCC_pancreas's result when nms == 1.
    """
    fg_prob = F.softmax(out, 1)[:, 1, :, :]
    masks = (fg_prob >= thres).type(torch.int64).contiguous()
    return LargestCC_pancreas(masks) if nms == 1 else masks

def LargestCC_pancreas(segmentation):
    """Keep only the largest connected foreground component of each sample in a batch.

    Args:
        segmentation: mask tensor of shape (N, ...), one mask per sample.

    Returns:
        float32 tensor of the same shape, on the same device as ``segmentation``.
        Samples without any labeled component are returned unchanged.
    """
    masks = []
    for sample in segmentation.detach().cpu().numpy():
        components = label(sample)  # skimage.measure.label: 0 = background, 1..K = components
        if components.max() != 0:
            # bincount()[1:] skips the background count; +1 maps argmax back to a label id.
            largest = components == np.argmax(np.bincount(components.flat)[1:]) + 1
        else:
            largest = sample
        masks.append(largest.astype(np.float32))

    if not masks:
        return torch.zeros(segmentation.shape, dtype=torch.float32, device=segmentation.device)
    # Stack once (building a tensor from a list of ndarrays is slow) and keep the input's
    # device instead of assuming CUDA is available.
    return torch.from_numpy(np.stack(masks)).to(segmentation.device)

def save_net_opt(net, optimizer, path):
    """Save a checkpoint with the network ('net') and optimizer ('opt') state dicts to ``path``."""
    checkpoint = dict(net=net.state_dict(), opt=optimizer.state_dict())
    torch.save(checkpoint, str(path))

def load_net_opt(net, optimizer, path, map_location=None):
    """Restore ``net`` and ``optimizer`` from a checkpoint with keys 'net' and 'opt'.

    Args:
        map_location: forwarded to torch.load. Use e.g. 'cpu' to load a GPU-saved
            checkpoint on a machine without CUDA. The default None keeps the saved devices
            (the previous behavior).
    """
    state = torch.load(str(path), map_location=map_location)
    net.load_state_dict(state['net'])
    optimizer.load_state_dict(state['opt'])

def load_net(net, path, map_location=None):
    """Restore only the network weights (key 'net') from a checkpoint file.

    Args:
        map_location: forwarded to torch.load (e.g. 'cpu' for GPU-saved checkpoints on a
            CPU-only machine). The default None keeps the saved devices, as before.
    """
    state = torch.load(str(path), map_location=map_location)
    net.load_state_dict(state['net'])

def get_current_consistency_weight(epoch, config=None):
    """Consistency-loss weight ``consistency * sigmoid_rampup(epoch, consistency_rampup)``.

    This definition shadows the earlier ``get_current_consistency_weight(epoch, args)`` from
    a previous cell. The optional ``config`` lets calls written against that two-argument
    signature keep working (the second argument is passed positionally) instead of raising
    TypeError. When omitted, the global ``args`` is used, as before.
    Unlike the earlier definition, no factor of 5 is applied here.
    """
    cfg = args if config is None else config
    return cfg.consistency * sigmoid_rampup(epoch, cfg.consistency_rampup)

In [28]:
# Run-level settings unpacked from `args`
train_data_path = args.root_path
pre_max_iterations = args.pre_max_iterations
self_max_iterations = args.self_max_iteration
base_lr = args.base_lr
# reduction='none' keeps per-element loss values (no averaging inside the criterion).
CE = nn.CrossEntropyLoss(reduction='none')

# Reproducibility: turn off cuDNN autotuning, request deterministic kernels, then seed
# torch (CPU + current CUDA device), Python's `random` and NumPy's global RNG.
if args.deterministic:
    cudnn.benchmark, cudnn.deterministic = False, True
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, random.seed, np.random.seed):
        seed_fn(args.seed)

