In [1]:
import os 
import numpy as np 
import random 
import time 
import logging 
from tqdm import tqdm


import torch 
import torch.optim as optim
from torchvision import transforms
import torch.nn.functional as F 
import torch.nn as nn 
from torch.utils.data import DataLoader
from torch.nn.parameter import Parameter


from utils.params import params 
from utils.losses import DiceLoss, softmax_mse_loss, softmax_kl_loss, l_correlation_cos_mean
from networks.utils import BCP_net, get_current_consistency_weight, update_ema_variable
from dataset.basedataset import BaseDataset
from dataset.utils import TwoStreamBatchSampler, RandomGenerator, patients_to_slices

from tensorboardX import SummaryWriter

In [2]:
# Experiment configuration produced by utils.params.params().
# Later cells read args.root_dir, args.patch_size, args.label_num, args.batch_size,
# args.labeled_bs, args.seed and args.num_classes from this object.
args = params()

#### 1. ACDC Dataset

In [3]:
# Build the training and validation datasets from args.root_dir.
# Only the training split receives a transform: RandomGenerator configured with args.patch_size.
train_transform = transforms.Compose([RandomGenerator(args.patch_size)])
train_db = BaseDataset(root_path=args.root_dir, split='train', transform=train_transform)

# The validation split is created without a transform argument.
val_db = BaseDataset(root_path=args.root_dir, split='val')


# Split the training set into labeled and unlabeled index ranges.
# The value returned by patients_to_slices is used as the boundary index:
# samples [0, labeled_slices) are treated as labeled, the remainder as unlabeled.
labeled_slices = patients_to_slices(args.root_dir, args.label_num)

# Scale to percent BEFORE rounding. Rounding the raw fraction to one decimal first
# would snap the result to multiples of 10 (e.g. 5.2% would be reported as 10.0%).
label_ratio = round(100 * labeled_slices / len(train_db), 1)
print(f'Number of labeled data in used: {label_ratio}%')

labeled_idxs = list(range(labeled_slices))
unlabeled_idxs = list(range(labeled_slices, len(train_db)))

# Sampler arguments: labeled indices, unlabeled indices, total batch size,
# and the number of non-labeled samples per batch (batch_size - labeled_bs).
batch_sampler = TwoStreamBatchSampler(
    labeled_idxs,
    unlabeled_idxs,
    args.batch_size,
    args.batch_size - args.labeled_bs,
)

# Create dataloader
def worker_init_fn(worker_id):
    """Seed Python's `random` module for a single DataLoader worker.

    The seed is `args.seed` offset by `worker_id`, so every worker process
    receives a distinct seed. Only the `random` module is seeded here.
    """
    worker_seed = args.seed + worker_id
    random.seed(worker_seed)
 
# Training batches come from the two-stream batch sampler; validation yields
# one sample at a time, in dataset order.
trainloader = DataLoader(
    train_db,
    batch_sampler=batch_sampler,
    num_workers=4,
    pin_memory=True,
    worker_init_fn=worker_init_fn,
)
valloader = DataLoader(val_db, batch_size=1, shuffle=False, num_workers=1)

# Sanity check: fetch one training batch, move it to the GPU and print the tensor shapes.
dataiter = iter(trainloader)
sampled_batch = next(dataiter)
volume_image = sampled_batch['image'].cuda()
volume_label = sampled_batch['label'].cuda()

# The first args.labeled_bs images of the batch are sliced off as the labeled part,
# the remaining images as the unlabeled part.
labeled_volume_batch = volume_image[:args.labeled_bs]
unlabeled_volume_batch = volume_image[args.labeled_bs:]

print(f'X.shape = {volume_image.shape}')
print(f'Y.shape = {volume_label.shape}')

Number of labeled data in used: 10.0%
X.shape = torch.Size([24, 1, 256, 256])
Y.shape = torch.Size([24, 256, 256])


#### 2. Linear Transform

In [None]:
class Linear_Vector(nn.Module):
    """Learnable square linear transform applied to a matrix of column vectors.

    The module owns one trainable matrix ``G`` (stored as ``self.params``) and
    computes ``q = G @ w``.

    Shapes:
        G: (ndim, ndim)
        w: (ndim, k)
        q: (ndim, k)

    Args:
        ndim: Size of both dimensions of the transform matrix.
        ratio_init: Standard deviation of the zero-mean normal distribution
            used to initialise ``G``. Defaults to 0.3.
    """

    def __init__(self, ndim, ratio_init=0.3):
        super(Linear_Vector, self).__init__()
        self.ndim = ndim
        self.ratio_init = ratio_init
        # torch.empty makes it explicit that the values are placeholders until
        # initialize() overwrites them.
        self.params = Parameter(torch.empty(self.ndim, self.ndim))  # linear transform matrix G
        self.initialize()

    def initialize(self):
        """Fill the transform matrix in place with samples from N(0, ratio_init**2)."""
        # nn.init.normal_ fills the whole tensor at once instead of mutating
        # `.data` on per-row views of the parameter.
        nn.init.normal_(self.params, mean=0.0, std=self.ratio_init)

    def forward(self, x):
        """Return the matrix product ``G @ x`` for a 2-D ``x`` of shape (ndim, k)."""
        return torch.mm(self.params, x)

In [18]:
# Shape check for Linear_Vector: apply a (ndim, ndim) transform to a (ndim, k) matrix.
ndim, k = 64, 10
w1 = torch.randn(ndim, k)
w2 = torch.randn(ndim, k)
linear_trans1 = Linear_Vector(ndim=ndim)

# The transformed matrix has the same (ndim, k) shape as the input.
q21 = linear_trans1(w2)
print(q21.shape)

torch.Size([64, 10])


#### 3. Understand Loss

In [4]:
# Build the network with a single input channel and args.num_classes output classes,
# then run the sampled batch through it once to inspect the output shape and dtype.
base_model = BCP_net(in_chns=1, class_num=args.num_classes)
outputs = base_model(volume_image)

print(f'Output.shape = {outputs.shape}')
print(f'Type of output: {outputs.dtype}')

Output.shape = torch.Size([24, 4, 256, 256])
Type of output: torch.float32


In [None]:
# Use the same class count the network is built with (args.num_classes) instead of a
# hard-coded 4, so the Dice criterion cannot drift out of sync with the model.
dice_loss_fn = DiceLoss(n_classes=args.num_classes)


def supervised_loss(outputs, target, alpha=0.5):
    """Compute the supervised loss for CauSSL (on labeled data only).

    loss = alpha * (CE + Dice); with the default alpha = 0.5 this is the
    average of the cross-entropy and Dice terms.

    Args:
        outputs: Network outputs, passed unchanged to both F.cross_entropy and
            dice_loss_fn. In this notebook they have shape (B, num_classes, H, W).
        target: Ground-truth labels; cast to long before use. In this notebook
            they have shape (B, H, W).
        alpha: Scale factor applied to the sum of the two loss terms.

    Returns:
        alpha * (cross-entropy + Dice) as a tensor.
    """
    # Cross-entropy needs integer class indices.
    target = target.long()
    loss_ce = F.cross_entropy(outputs, target)

    # The Dice criterion is given the target with an extra axis at dim 1.
    # NOTE(review): `outputs` is passed to the Dice loss without a softmax; confirm
    # whether DiceLoss expects probabilities rather than raw network outputs.
    loss_dice = dice_loss_fn(outputs, target.unsqueeze(1))

    return alpha * (loss_ce + loss_dice)

In [8]:
# Sanity check of supervised_loss: only the first args.labeled_bs samples of the
# batch are used, for both the network outputs and the labels.
outputs = base_model(volume_image)
labeled_outputs = outputs[:args.labeled_bs]
labeled_targets = volume_label[:args.labeled_bs]
loss_sup = supervised_loss(labeled_outputs, labeled_targets)
print(loss_sup)

tensor(1.1995, device='cuda:0', grad_fn=<MulBackward0>)
