In [1]:
import numpy as np 
import matplotlib.pyplot as plt
import h5py 
import os 
import random
import itertools
import math
import sys
from skimage.measure import label
from scipy.ndimage import zoom, rotate 
from medpy import metric 
from tqdm import tqdm 
import logging 
import os 


import torch 
import torch.nn as nn 
import torch.nn.functional as F 
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import Sampler 
import torchvision.transforms as transforms 
import torch.optim as optim 
from tensorboardX import SummaryWriter

In [2]:
# Params 
class params:
    """Hyper-parameters and paths for the BCP training run.

    Every setting becomes a plain instance attribute (in the order listed
    below), so individual values can be overridden after construction.
    """

    def __init__(self):
        settings = {
            'root_path': '/kaggle/input/la-dataset/LA',
            'exp': 'BCP',
            'model': 'VNet',
            'pre_max_iterations': 10,
            'self_max_iteration': 100,
            'max_samples': 80,
            'labeled_bs': 4,
            'batch_size': 8,
            'base_lr': 0.01,
            'deterministic': 1,
            'labelnum': 8,
            'consistency': 1.0,
            'consistency_rampup': 40.0,
            'magnitude': 10.0,
            'seed': 10,
            'gpu': '0',
            # Setting of BCP
            'u_weight': 0.5,
            'mask_ratio': 2 / 3,
            'patch_size': [112, 112, 80],
            # Setting of mixup
            'u_alpha': 2.0,
            'loss_weight': 0.5,
        }
        # The dict literal is rebuilt on every call, so mutable values such as
        # patch_size are never shared between instances.
        for name, value in settings.items():
            setattr(self, name, value)


args = params()

#### 1. Dataset 

In [3]:
class BaseDataset(Dataset):
    """Dataset that reads case names from a ``*.list`` file and loads ``.h5`` samples.

    The dataset layout (``'ACDC'`` or ``'LA'``) is chosen from the LAST path
    component of ``root_path``. Both ``'LA'`` and ``'/kaggle/input/la-dataset/LA'``
    therefore select the LA layout.

    Params:
        - root_path (str): dataset root; its basename must be ``'ACDC'`` or ``'LA'``.
        - split (str): ``'train'`` or ``'val'`` (``'val'`` exists only for ACDC).
        - transform (callable | None): applied to ``{'image', 'label'}`` dicts,
          train split only.
        - num (int | None): if an int, keep only the first ``num`` cases.

    Raises:
        ValueError: unknown dataset name or unsupported split.
    """

    def __init__(self, root_path, split= 'train', transform= None, num= None):
        self.root_path = root_path
        self.split = split
        self.transform = transform
        self.sample_list = []
        # Compare only the final path component. Comparing the whole path with
        # 'LA' / 'ACDC' never matched a real path such as
        # '/kaggle/input/la-dataset/LA' and left `train_list` unbound.
        self.dataset_name = os.path.basename(os.path.normpath(root_path))

        # Map each supported split to its case-list file.
        if self.dataset_name == 'ACDC':
            list_files = {
                'train': os.path.join(root_path, 'train_slices.list'),
                'val': os.path.join(root_path, 'val.list'),
            }
            self.train_files = os.path.join(root_path, 'data/slices')
            self.valid_files = os.path.join(root_path, 'data')
        elif self.dataset_name == 'LA':
            list_files = {'train': os.path.join(root_path, 'train.list')}
            self.train_files = os.path.join(root_path, '2018LA_Seg_Training Set')
        else:
            raise ValueError(f'Dataset: {self.dataset_name} is not supported')

        if split not in list_files:
            raise ValueError(f'Mode: {self.split} is not supported')

        with open(list_files[split], 'r') as file:
            # strip() also removes a stray '\r'; blank lines are skipped.
            self.sample_list = [line.strip() for line in file if line.strip()]

        # Use only the first `num` cases (e.g. the labeled subset).
        if isinstance(num, int):
            self.sample_list = self.sample_list[:num]

    def __len__(self):
        return len(self.sample_list)

    def __getitem__(self, index):
        """Load one case as ``{'image', 'label', 'case'}``; transform is applied for 'train'."""
        case = self.sample_list[index]
        if self.dataset_name == 'ACDC':
            folder = self.train_files if self.split == 'train' else self.valid_files
            file_path = os.path.join(folder, f'{case}.h5')
        else:  # 'LA' -- __init__ rejects every other dataset name
            file_path = os.path.join(self.train_files, f'{case}/mri_norm2.h5')

        # The context manager closes the HDF5 handle; it was previously leaked.
        with h5py.File(file_path, 'r') as h5f:
            sample = {'image': h5f['image'][:], 'label': h5f['label'][:]}

        if self.split == 'train' and self.transform is not None:
            sample = self.transform(sample)
        sample['case'] = case
        return sample

In [11]:
def random_rot_flip(image, label):
    """Apply one shared random 90-degree rotation and flip to ``image`` and ``label``.

    ``k`` in {0, 1, 2, 3} quarter turns are drawn first, then a flip axis in
    {0, 1}, both from NumPy's global RNG. ``np.rot90`` acts on axes 0 and 1.
    Returns the ``(image, label)`` pair.
    """
    quarter_turns = np.random.randint(0, 4)
    image, label = (np.rot90(arr, quarter_turns) for arr in (image, label))

    flip_axis = np.random.randint(0, 2)
    return tuple(np.flip(arr, flip_axis) for arr in (image, label))


class RandomRotFlip:
    """Transform that applies ``random_rot_flip`` jointly to a sample's image and label.

    Returns a new dict holding only the ``'image'`` and ``'label'`` entries.
    """

    def __call__(self, sample):
        new_image, new_label = random_rot_flip(sample['image'], sample['label'])
        return {'image': new_image, 'label': new_label}
    
class RandomCrop:
    """Randomly crop a 3-D ``(image, label)`` pair to ``output_size``.

    If any label axis is shorter than the requested size, all three axes are
    zero-padded on both sides first. Each side gets half the deficit (rounded
    up) plus ``padding_margin``, so the crop always fits. The same crop window
    is used for image and label.

    Params:
        - output_size (sequence of 3 ints): crop size along axes 0, 1, 2.
        - padding_margin (int): extra zero voxels per side when padding is needed.
    """

    def __init__(self, output_size, padding_margin = 3 ):
        self.output_size = output_size
        self.padding_margin = padding_margin

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        out_w, out_h, out_d = self.output_size

        # Pad only if some axis is too small for the crop.
        if label.shape[0] < out_w or label.shape[1] < out_h or label.shape[2] < out_d:
            pads = []
            for size, target in zip(label.shape[:3], (out_w, out_h, out_d)):
                deficit = target - size
                # (deficit + 1) // 2 is ceil(deficit / 2). It guarantees
                # padded size >= target even with padding_margin == 0 and an
                # odd deficit.
                per_side = max((deficit + 1) // 2 + self.padding_margin, 0)
                pads.append((per_side, per_side))
            image = np.pad(image, pads, mode= 'constant', constant_values= 0)
            label = np.pad(label, pads, mode= 'constant', constant_values= 0)

        # randint's upper bound is exclusive. The +1 makes the last valid
        # offset reachable and avoids ValueError when an axis equals the crop size.
        w, h, d = image.shape
        w1 = np.random.randint(0, w - out_w + 1)
        h1 = np.random.randint(0, h - out_h + 1)
        d1 = np.random.randint(0, d - out_d + 1)

        image = image[w1 : w1 + out_w, h1 : h1 + out_h, d1 : d1 + out_d]
        label = label[w1 : w1 + out_w, h1 : h1 + out_h, d1 : d1 + out_d]

        return {'image': image, 'label': label}

class ToTensor:
    """Convert a NumPy ``{'image', 'label'}`` sample into torch tensors.

    The image gains a leading channel axis (``(...) -> (1, ...)``) and becomes
    ``float32``. The label becomes ``int64``. Any input rank is accepted; the old
    code hardcoded three axes.
    """

    def __call__(self, sample):
        # .copy() gives contiguous arrays with positive strides. Views produced
        # by np.rot90 / np.flip have negative strides, which torch.from_numpy
        # cannot wrap.
        image = sample['image'].copy()
        label = sample['label'].copy()

        # Prepend the channel axis; astype also yields a fresh float32 array.
        image = image[np.newaxis, ...].astype(np.float32)

        return {
            'image': torch.from_numpy(image),
            'label': torch.from_numpy(label).to(torch.int64),
        }

In [5]:
def iterate_once(indices):
    """Return one shuffled copy of ``indices``, i.e. a single randomised pass.

    Uses ``np.random.permutation`` and therefore NumPy's global RNG.
    """
    shuffled = np.random.permutation(indices)
    return shuffled

def iterate_externally(indices):
    """Return an endless iterator over ``indices``, reshuffled after every full pass.

    Each pass is a fresh ``np.random.permutation``. A pass is drawn lazily, only
    when the previous one is exhausted (nothing is drawn at call time).
    """
    passes = map(np.random.permutation, itertools.repeat(indices))
    return itertools.chain.from_iterable(passes)

def grouper(iterable, n):
    """Collect ``iterable`` into consecutive ``n``-tuples; an incomplete tail is dropped.

    Example: ``grouper('ABCDEFG', 3) -> ('A', 'B', 'C'), ('D', 'E', 'F')``.
    """
    shared = iter(iterable)
    # zip pulls from the very same iterator n times per output tuple.
    return zip(*(shared for _ in range(n)))

class TwoStreamBatchSampler(Sampler):
    """Batch sampler that concatenates a primary and a secondary index stream.

    Each batch is ``batchsize - secondary_batchsize`` primary indices, taken
    from one shuffled pass over ``primary_indicies``, followed by
    ``secondary_batchsize`` secondary indices. The secondary indices come from
    an endlessly reshuffled stream. Iteration stops when the primary pass runs
    out, and an incomplete final primary group is dropped.
    """

    def __init__(self, primary_indicies, secondary_indicies, batchsize, secondary_batchsize):
        self.primary_indicies = primary_indicies
        self.secondary_indicies = secondary_indicies
        self.secondary_batchsize = secondary_batchsize
        self.primary_batchsize = batchsize - secondary_batchsize

        # Both streams must be able to fill their share of a batch.
        assert len(self.primary_indicies) >= self.primary_batchsize > 0
        assert len(self.secondary_indicies) >= self.secondary_batchsize > 0

    def __iter__(self):
        primary_groups = grouper(iterate_once(self.primary_indicies), self.primary_batchsize)
        secondary_groups = grouper(iterate_externally(self.secondary_indicies), self.secondary_batchsize)
        return (first + second for first, second in zip(primary_groups, secondary_groups))

    def __len__(self):
        # Number of complete primary groups in one pass.
        return len(self.primary_indicies) // self.primary_batchsize

#### 2. Loss function 

In [24]:
def to_one_hot(tensor, num_classes):
    """Convert an integer label tensor with a singleton channel axis to one-hot form.

    Params:
        - tensor (torch.Tensor): int64 labels of shape ``N x 1 x ...`` with
          values in ``[0, num_classes)``.
        - num_classes (int): number of classes ``C``.

    Returns:
        torch.Tensor: shape ``N x C x ...`` on ``tensor``'s device, with 1 at
        channel ``tensor[n, 0, ...]`` and 0 elsewhere.
    """
    assert tensor.max().item() < num_classes
    assert tensor.min().item() >= 0

    size = list(tensor.size())
    assert size[1] == 1
    size[1] = num_classes
    # Allocate directly on the input's device. The previous `is_cuda` check
    # left any non-CUDA accelerator with a device mismatch in scatter_.
    one_hot = torch.zeros(size, device= tensor.device)

    # Along dim 1, write 1 at the index given by each label.
    return one_hot.scatter_(1, tensor, 1)

def get_probability(logits):
    """Turn raw model outputs into per-class probabilities along dim 1.

    Params:
        - logits (torch.Tensor): prediction of model, ``N x C x ...``.

    Returns:
        tuple: ``(probs, nclass)``.
          - For ``C > 1``: ``probs`` is a softmax over dim 1 and ``nclass == C``.
          - For ``C == 1``: the channel is treated as a binary foreground logit,
            ``probs`` is ``[1 - sigmoid, sigmoid]`` and ``nclass == 2``.
    """
    nclass = logits.size(1)
    if nclass > 1:
        return F.softmax(logits, dim= 1), nclass

    # Single-channel case. `nclass` used to be left unset here (UnboundLocalError).
    pred = torch.sigmoid(logits)
    return torch.cat([1 - pred, pred], dim= 1), 2


class DiceLoss3D(nn.Module):
    """Soft multi-class Dice loss with an optional per-voxel mask.

    Params:
        - num_classes (int): number of classes ``C``.
        - class_weights (sequence | None): per-class weights (length ``C``).
          The default is uniform, which reduces to a plain mean over classes.
        - smooth (float): additive smoothing for numerator and denominator.
    """

    def __init__(self, num_classes, class_weights= None, smooth= 1e-5):
        super(DiceLoss3D, self).__init__()
        self.num_classes = num_classes
        self.smooth = smooth
        if class_weights is None:
            self.class_weights = nn.Parameter(torch.ones(1, num_classes).type(torch.float32), requires_grad= False)
        else:
            class_weights = np.array(class_weights)
            self.class_weights = nn.Parameter(torch.tensor(class_weights, dtype= torch.float32), requires_grad= False)

    # TODO: looks unnecessary -- placeholder that does nothing and returns None.
    def prob_forward(self):
        pass

    def forward(self, logits, target, mask= None):
        """Return ``1 - (class-weighted, batch-averaged) Dice`` as a scalar tensor.

        Params:
            - logits (torch.Tensor): raw outputs, ``N x C x ...``.
            - target (torch.Tensor): integer labels with ``N * V`` elements, where
              V is the number of voxels per sample. They are reshaped to ``N x 1 x V``.
            - mask (torch.Tensor | None): weights with ``N * V`` elements. When
              given, only masked voxels contribute to intersection and union.
        """
        N, nclass = logits.size(0), logits.size(1)

        # Keep the class axis. The old `view(N, 1, -1)` folded every class
        # into one channel. reshape also tolerates non-contiguous inputs.
        logits = logits.reshape(N, nclass, -1)
        target = target.reshape(N, 1, -1)

        pred, nclass = get_probability(logits)
        target_one_hot = to_one_hot(target.type(torch.long), nclass).type(torch.float32)

        inter = pred * target_one_hot  # N x C x V
        union = pred + target_one_hot  # N x C x V
        if mask is not None:
            mask = mask.reshape(N, 1, -1)
            inter = inter * mask
            union = union * mask
        inter = inter.sum(2)  # N x C
        union = union.sum(2)  # N x C

        dice = (2 * inter + self.smooth) / (union + self.smooth)  # N x C

        # Reduce to a scalar so the result is directly usable as a loss.
        # Previously an N x C tensor was returned and class_weights were ignored.
        # Weights are moved to dice's device because this module may not have
        # been moved along with the model.
        weights = self.class_weights.to(device= dice.device, dtype= dice.dtype).reshape(1, -1)
        weighted_dice = (dice * weights).sum(1) / weights.sum()
        return 1 - weighted_dice.mean()
    


In [25]:
DICE = DiceLoss3D(num_classes= 2)
CE = nn.CrossEntropyLoss(reduction= 'none')


def mix_loss(output, img_l, patch_l, mask, l_weight = 1.0, u_weight = 0.5, unlab = False):
    """Dice + cross-entropy loss for a prediction whose target is a mix of two label maps.

    Voxels selected by ``mask`` are supervised by ``img_l``. Voxels selected by
    ``1 - mask`` are supervised by ``patch_l``. The ``img_l`` region is weighted
    by ``l_weight`` and the ``patch_l`` region by ``u_weight``; ``unlab=True``
    swaps the two weights. Returns ``(dice_term + ce_term) / 2``.
    """
    if unlab:
        image_weight, patch_weight = u_weight, l_weight
    else:
        image_weight, patch_weight = l_weight, u_weight

    image_target = img_l.type(torch.int64)
    patch_target = patch_l.type(torch.int64)
    patch_mask = 1 - mask

    def masked_ce(target, region, weight):
        # Weighted per-voxel CE, averaged over the selected region
        # (epsilon avoids division by zero for an empty region).
        return weight * (CE(output, target) * region).sum() / (region.sum() + 1e-16)

    loss_dice = DICE(output, image_target, mask) * image_weight
    loss_dice = loss_dice + DICE(output, patch_target, patch_mask) * patch_weight

    loss_ce = masked_ce(image_target, mask, image_weight)
    loss_ce = loss_ce + masked_ce(patch_target, patch_mask, patch_weight)

    return (loss_dice + loss_ce) / 2


#### 3. Validation 

In [None]:
def test_single_case(model, image, stride_xy, stride_z, patch_size , num_classes= 1): 
    w, h, d = image.shape 
    pad_w, pad_h, pad_d = max(0, patch_size[0] - w), max(0, patch_size[1] -h), max(0, patch_size[2] - d)
    pad = [(pad_w // 2, pad_w - pad_w //2), 
           (pad_h // 2, pad_h - pad_h //2), 
           (pad_d // 2, pad_d - pad_d //2)]
    
    if any(p > 0 for p in [pad_w, pad_h, pad_d]): 
        image = np.pad(image, pad, mode= 'constant', constant_values= 0) 
    
    ww, hh, dd = image.shape 

    # Calculate the number of patches along each dimension 
    sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1 
    sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1 
    sz = math.ceil((dd - patch_size[2]) / stride_z) + 1 

    # Initialize score map and count map 
    score_map = np.zeros((num_classes, ww, hh, dd), dtype= np.float32) 
    cnt = np.zeros((ww, hh, dd), dtype= np.float32) 

    # Perform sliding-window inference 
    for x in range(sx): 
        xs = min(stride_xy *x , ww - patch_size[0])
        for y in range(y): 
            ys = min(stride_xy * y, hh - patch_size[1])
            for z in range(sz): 
                zs = min(stride_z * z, dd - patch_size[2]) 
                

def var_all_case_LA(model, num_classes, patch_size = (112, 112, 80), stride_xy= 18, stride_z = 4):
    """Placeholder for evaluating ``model`` on all LA cases.

    TODO: not implemented yet -- the body is empty and the function returns None.
    """
    return None