In [3]:
import os 
import numpy as np 
import random 
import time 
import logging 
from tqdm import tqdm


import torch 
import torch.optim as optim
from torchvision import transforms
import torch.nn.functional as F 
import torch.nn as nn 
from torch.utils.data import DataLoader
from torch.nn.parameter import Parameter


from utils.params import params 
from utils.losses import DiceLoss, softmax_mse_loss, softmax_kl_loss, l_correlation_cos_mean
from networks.utils import BCP_net, get_current_consistency_weight, update_ema_variable
from dataset.basedataset import BaseDataset
from dataset.utils import TwoStreamBatchSampler, RandomGenerator, patients_to_slices

from tensorboardX import SummaryWriter

In [4]:
args = params()

#### 1. ACDC Dataset

In [5]:
# load dataset 
train_db = BaseDataset(
    root_path= args.root_dir, 
    split= 'train', 
    transform= transforms.Compose([RandomGenerator(args.patch_size)])
)

val_db = BaseDataset(
    root_path= args.root_dir, 
    split= 'val'
)


# split to labeled and unlabeled dataset 
labeled_slices = patients_to_slices(args.root_dir, args.label_num) 
label_ratio = round(labeled_slices / len(train_db), 1)* 100
print(f'Number of labeled data in used: {label_ratio}%')
labeled_idxs = list(range(0, labeled_slices) )
unlabeled_idxs = list(range(labeled_slices ,len(train_db)))
batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs, args.batch_size, args.batch_size - args.labeled_bs)

# Create dataloader
def worker_init_fn(worker_id):
    random.seed(args.seed + worker_id)
 
trainloader = DataLoader(train_db, batch_sampler= batch_sampler, num_workers= 4, pin_memory= True, worker_init_fn= worker_init_fn)
valloader = DataLoader(val_db, batch_size= 1, shuffle= False, num_workers=1)

# Check 
dataiter = iter(trainloader) 
sampled_batch = next(dataiter) 
volume_image, volume_label = sampled_batch['image'], sampled_batch['label']
volume_image, volume_label = volume_image.cuda(), volume_label.cuda()
labeled_volume_batch = volume_image[ : args.labeled_bs]
unlabeled_volume_batch = volume_image[args.labeled_bs :]
print(f'X.shape = {volume_image.shape}')
print(f'Y.shape = {volume_label.shape}')

Number of labeled data in used: 10.0%
X.shape = torch.Size([24, 1, 256, 256])
Y.shape = torch.Size([24, 256, 256])


#### 2. Linear Transform

In [6]:
class Linear_vector(nn.Module): 
    """
    Implement: 
        - Initialize the linear transform matrix. G.shape = (ndim, ndim) 
        - Apply Gb to performce linear transform 
    Formula: 
        qb.shape = G * w --> (ndim, k)
        w.shape = (ndim, k)
        G.shape = (ndim, ndim) 
    """
    def __init__(self, ndim): 
        super(Linear_vector, self).__init__()
        self.ndim = ndim 
        self.params = Parameter(torch.Tensor(self.ndim, self.ndim)) # Linear transform matrix
        self.ratio_init = 0.3 
        self.initialize()

    def initialize(self):
        """ 
        Initialize for Standard Scaler 
        mean = 0 
        std = ratio_init 
        """ 
        for param in self.params: 
            param.data.normal_(0, self.ratio_init)  

    def forward(self,x):
        result = torch.mm(self.params, x) # dot product  
        return result 

In [7]:
ndim = 64 
k = 10 
w1 = torch.randn(ndim, k) 
w2 = torch.randn(ndim, k)
linear_trans1 = Linear_vector(ndim= 64)

q21 = linear_trans1(w2) 
print(q21.shape)

torch.Size([64, 10])


#### 3. Understand Loss 

In [8]:
base_model = BCP_net(in_chns= 1, class_num= args.num_classes) 
outputs = base_model(volume_image)
print(f'Output.shape = {outputs.shape}')
print(f'Type of output: {outputs.dtype}')

Output.shape = torch.Size([24, 4, 256, 256])
Type of output: torch.float32


In [9]:
dice_loss_fn = DiceLoss(n_classes= 4) 
def supervised_loss(outputs, target, alpha= 0.5): 
    """
    Comptute supervised loss for CauSSL (on labeled data only)
    supervised = 0.5 ( CE + DICE )
    """
    # Compute CELoss 
    target = target.long() 
    loss_ce = F.cross_entropy(outputs, target) 

    # Compute DiceLoss 
    loss_dice = dice_loss_fn(outputs, target.unsqueeze(1)) 

    loss = alpha * ( loss_ce + loss_dice)
    return loss 

#### 4. Training process 

In [10]:
num_classes = args.num_classes
base_lr = args.base_lr 
labeled_bs = args.labeled_bs
max_iterations = args.max_iteration

# create 2 model with the same architecure
model1 = BCP_net(in_chns=1, class_num= 4) 
model2 = BCP_net(in_chns=1, class_num= 4) 
optimizer1 = optim.SGD(model1.parameters(), base_lr,  momentum= 0.9, weight_decay= 1e-4)
optimizer2 = optim.SGD(model2.parameters(), base_lr, momentum= 0.9, weight_decay= 1e-4)

model1.train() 
model2.train() 

# Initialize linear transform matrix (vector)
linear_params1 = [] 
linear_params2 = [] 
count = 0
for name, parameters in model1.named_parameters(): 
    if 'conv' in name and 'weight' in name: 
        if len(parameters.shape) == 4: 
            count += 1 
            outdim = parameters.shape[0] 
            linear_params1.append(Linear_vector(outdim))
            linear_params2.append(Linear_vector(outdim))

# Convert from list to torch
linear_params1 = nn.ModuleList(linear_params1)
linear_params2 = nn.ModuleList(linear_params2)    
linear_params1 = linear_params1.cuda() 
linear_params2 = linear_params2.cuda()

linear_optimizer1 = optim.Adam(linear_params1.parameters(), 2e-2) # Need consider about this hyper-parameters
linear_optimizer2 = optim.Adam(linear_params2.parameters(), 2e-2)

if args.consistency_type == 'mse': 
    consistency_criterion = softmax_mse_loss
elif args.consistency_type == 'kl': 
    consistency_criterion = softmax_kl_loss
else: 
    assert False, args.consistency_type

# Training process - Cross Pseudo Supervision FrameWork 
writer = SummaryWriter() 
logging.info(f'{len(trainloader)} per epoch')

iter_num = 0 
iter_num_max = 0 
max_epoch = max_iterations // len(trainloader) + 1 
lr_ = base_lr 
model1.train() 
model2.train() 
for epoch_num in tqdm(range(max_epoch), ncols=70): 
    time1 = time.time() 
    for i_batch, sampled_batch in enumerate(trainloader): 
        time2 = time.time() 

        # Update linear transform matrix periodly 
        if iter_num > args.start_step1 and iter_num % args.min_step == 0: 
            icm_loss1 = -l_correlation_cos_mean(model1, model2, linear_params1)
            icm_loss2 = -l_correlation_cos_mean(model1, model2, linear_params2)

            linear_optimizer1.zero_grad() 
            linear_optimizer2.zero_grad() 

            icm_loss1.backward() 
            icm_loss2.backward() 
            linear_optimizer1.step() 
            linear_optimizer2.step() 

            iter_num_max += 1 

            writer.add_scalar('loss/icm_loss1_max', -icm_loss1, iter_num_max)
            writer.add_scalar('loss/icm_loss2_max', -icm_loss2, iter_num_max)
        
        volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
        volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda() 
        unlabeled_batch = volume_batch[labeled_bs :]

        outputs1 = model1(volume_batch)
        outputs2 = model2(volume_batch) 

        supervised_loss1 = supervised_loss(outputs1[: labeled_bs], volume_batch[labeled_bs])
        supervised_loss2 = supervised_loss(outputs2[: labeled_bs], label_batch[: labeled_bs])
        print(supervised_loss1)
        break 
    break 




  0%|                                          | 0/10 [00:00<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 192.00 MiB. GPU 0 has a total capacity of 7.79 GiB of which 66.88 MiB is free. Including non-PyTorch memory, this process has 7.66 GiB memory in use. Of the allocated memory 7.45 GiB is allocated by PyTorch, and 46.73 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)