In [15]:
import os
import argparse
import torch
import cv2
import seaborn as sns
import random
import numpy as np

from torchvision import transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt

### UNet

In [9]:
from unet import UNet
from dataset import BaseDataSets, RandomGenerator, TwoStreamBatchSampler

### Arguments

In [12]:
# Settings object shared by the cells below
class params:
    """Plain attribute container holding the notebook's configuration.

    Each instance is populated with the same fixed values; fields that the
    code visible in this notebook reads are annotated with where they go.
    Fields without such a note are not read by the code shown here.
    """

    def __init__(self):
        # Built inside __init__ so every instance owns fresh objects
        # (notably its own patch_size list).
        settings = {
            'root_dir': '/kaggle/input/acdc-dataset/ACDC',  # passed to BaseDataSets and patients_to_slices
            'saved_path': '',            # prefix of the checkpoint path (BCP_model_path)
            'exp': 'BCP',
            'model': 'unet',             # net_type handed to net_factory
            'pretrain_iterations': 200,

            'selftrain_iterations': 200,
            'batch_size': 24,            # total batch size given to TwoStreamBatchSampler
            'deterministic': 1,
            'base_lr': 0.01,
            'patch_size': [256, 256],    # passed to RandomGenerator
            'seed': 42,
            'num_classes': 4,            # class_num handed to net_factory
            'stage_name': 'selftrain',

            # label and unlabel
            'labeled_bs': 2,             # batch_size - labeled_bs goes to the sampler
            'label_num': 7,              # patients_to_slices key; also part of output paths
            'u_weight': 0.5,

            # Cost
            'gpu': '0',                  # written to CUDA_VISIBLE_DEVICES
            'consistency': 0.1,
            'consistency_rampup': 200.0,
            'magnitude': '6.0',          # stored as a string
            's_param': 6,
        }
        for field_name, field_value in settings.items():
            setattr(self, field_name, field_value)


args = params()

In [13]:
def net_factory(net_type="unet", in_chns=1, class_num=2, mode = "train", tsne=0):
    """Build the requested network and move it to the GPU.

    Args:
        net_type: Architecture name; only "unet" is supported.
        in_chns: Number of input channels forwarded to ``UNet``.
        class_num: Number of output classes forwarded to ``UNet``.
        mode: "train" or "test". Both construct the same UNet; the caller
            decides on ``.train()`` / ``.eval()``.
        tsne: Accepted for call compatibility; not used.

    Returns:
        The ``UNet`` instance after ``.cuda()``.

    Raises:
        ValueError: If the ``net_type`` / ``mode`` combination is not
            supported. Previously this surfaced as an ``UnboundLocalError``
            on the ``return`` statement.
    """
    if net_type == "unet" and mode in ("train", "test"):
        return UNet(in_chns=in_chns, class_num=class_num).cuda()
    raise ValueError(
        f"Unsupported network configuration: net_type={net_type!r}, mode={mode!r}"
    )

In [None]:
# Method names for the plots; not referenced by the code visible in this notebook.
model_name = ["BCP"]
# Checkpoint location: <saved_path>/labeled_<label_num>/BCP.pth.
# os.path.join supplies the separator that plain string concatenation would
# drop when saved_path is non-empty and has no trailing slash.
BCP_model_path = os.path.join(args.saved_path, f'labeled_{args.label_num}', 'BCP.pth')

# Upper bound on the number of pixels per group used for each KDE curve.
limit_pixels = 500
# Forwarded to seaborn's `bw_adjust`, which multiplicatively scales the KDE
# bandwidth; values below 1 give less smoothing than the default.
# NOTE(review): the reason for picking 0.5 is not recorded here — confirm.
bandwidth_adjust = 0.5
# Width of the plotted KDE curves.
line_wid = 5

In [None]:
def get_ACDC_masks(output):
    """Turn raw network scores into a hard class-index map.

    Applies a softmax over dimension 1 and returns, for every remaining
    position, the index along dimension 1 holding the largest probability.
    """
    class_probabilities = F.softmax(output, dim=1)
    return torch.max(class_probabilities, dim=1).indices

def patients_to_slices(dataset, patients_num):
    """Look up how many slices belong to the first ``patients_num`` patients.

    Args:
        dataset: Dataset name or path; must contain "ACDC" or "Prostate".
        patients_num: Number of labeled patients (int or str). Only the
            counts listed in the tables below are known.

    Returns:
        The slice count recorded for that dataset / patient count.

    Raises:
        ValueError: If ``dataset`` names neither supported dataset.
        KeyError: If ``patients_num`` has no entry for the dataset.
    """
    if "ACDC" in dataset:
        ref_dict = {"1": 32, "3": 68, "7": 136,
                    "14": 256, "21": 396, "28": 512, "35": 664, "70": 1312}
    elif "Prostate" in dataset:
        # The original condition was the bare string "Prostate", which is
        # always truthy and therefore matched every non-ACDC dataset.
        ref_dict = {"2": 27, "4": 53, "8": 120,
                    "12": 179, "16": 256, "21": 312, "42": 623}
    else:
        raise ValueError(
            f"Unknown dataset {dataset!r}: expected a name or path containing "
            "'ACDC' or 'Prostate'"
        )

    key = str(patients_num)
    if key not in ref_dict:
        raise KeyError(
            f"No slice count recorded for {patients_num} labeled patients; "
            f"known values: {sorted(ref_dict, key=int)}"
        )
    return ref_dict[key]

In [None]:
def plot_kde(BCP_feature, BCP_pred, labels, specific_c, f_dim, pic_num):
    """Plot and save KDE curves of labeled vs. unlabeled features for one class.

    Rows of the three arrays must correspond to the same pixels. Only pixels
    whose prediction AND label both equal ``specific_c`` are used; each
    selected pixel contributes the mean of its feature vector.

    Args:
        BCP_feature: 2-D numpy array, one feature vector per row.
        BCP_pred: 2-D numpy array of predicted classes, indexed as ``[rows, :]``.
        labels: 2-D numpy array of reference classes, same layout as ``BCP_pred``.
        specific_c: Class index to select.
        f_dim: Feature dimension; only used in the output directory name.
        pic_num: Running number used in the output file name.

    Reads the module-level ``args``, ``limit_pixels``, ``bandwidth_adjust`` and
    ``line_wid``. Writes a PNG below ``KDE/ACDC/...`` and prints its path.
    Returns ``None``; nothing is plotted when no pixel qualifies.
    """
    total_pixel = BCP_feature.shape[0]
    # Split rows into labeled / unlabeled parts in the proportion used by the
    # batch sampler (labeled_bs labeled samples out of batch_size). The previous
    # fixed "half + 1" split only matched a half/half batch.
    # NOTE(review): assumes rows are batch-major with the labeled samples first,
    # as the [:labeled_pixel] slicing here already did — confirm against the sampler.
    labeled_pixel = total_pixel * args.labeled_bs // args.batch_size
    save_path = f"KDE/ACDC/{f_dim}/labeled_{args.label_num}/class_{specific_c}"
    os.makedirs(save_path, exist_ok=True)

    # Row indices (relative to each part) where prediction / label hit the class.
    l_pred = np.where(BCP_pred[:labeled_pixel, :] == specific_c)[0]
    u_pred = np.where(BCP_pred[labeled_pixel:, :] == specific_c)[0]
    l_lab = np.where(labels[:labeled_pixel, :] == specific_c)[0]
    # Compare against the labels here as well; the original reused BCP_pred,
    # which made the intersection below a no-op for the unlabeled part.
    u_lab = np.where(labels[labeled_pixel:, :] == specific_c)[0]

    # Correctly predicted pixels; unlabeled indices are shifted back to
    # positions in the full arrays.
    correct_cor_l = np.intersect1d(l_pred, l_lab)
    correct_cor_u = np.intersect1d(u_pred, u_lab) + labeled_pixel

    # Use the same number of pixels from both groups, capped by limit_pixels.
    pixel_num = min(len(correct_cor_l), len(correct_cor_u), limit_pixels)
    print(f"Total {pixel_num} pixels for class {specific_c}")
    if pixel_num == 0:
        print(f"No correctly predicted pixels for class {specific_c}; skipping plot {pic_num}")
        return

    # Collapse each selected feature vector to its mean value.
    BCP_feature_l = np.mean(BCP_feature[correct_cor_l[:pixel_num]], axis=1)
    BCP_feature_u = np.mean(BCP_feature[correct_cor_u[:pixel_num]], axis=1)

    # Set the context before the axes exist so the font scaling applies to them.
    sns.set_context("notebook", font_scale=2)
    fig, ax = plt.subplots(figsize=(29, 4))
    # Green: labeled pixels, blue: unlabeled pixels. `linewidth` is the
    # matplotlib line property; `line_width` is not accepted.
    sns.kdeplot(BCP_feature_l, bw_adjust=bandwidth_adjust, color='g', linewidth=line_wid, ax=ax)
    sns.kdeplot(BCP_feature_u, bw_adjust=bandwidth_adjust, color='b', linewidth=line_wid, ax=ax)
    ax.tick_params(axis='both', labelsize=16)
    ax.set_ylabel(" ")
    ax.set_title("BCP")

    out_file = f"{save_path}/kde_test_mean{pic_num}_{args.label_num}_{specific_c}.png"
    fig.savefig(out_file)
    print(f"Save to: {out_file}")
    # Close instead of clf(): this function runs once per batch and open
    # figures would otherwise accumulate.
    plt.close(fig)

In [None]:
def Inference(args, specific_c=2):
    """Run the trained BCP network over the training set and save KDE plots.

    For every batch the network's feature map is flattened to one row per
    pixel, the labels are resized to the feature map's spatial size, and
    ``plot_kde`` is called for class ``specific_c``.

    Args:
        args: Settings object; reads ``gpu``, ``model``, ``num_classes``,
            ``root_dir``, ``patch_size``, ``label_num``, ``batch_size`` and
            ``labeled_bs``.
        specific_c: Class index whose features are plotted (default 2, the
            value that was previously hard-coded).

    Also reads the module-level ``BCP_model_path``. Returns ``None``.
    """
    # NOTE(review): this only takes effect if CUDA has not been initialised yet
    # in this process — confirm for the notebook's execution order.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    # Request the "train" construction path, which net_factory supports for
    # "unet"; inference behaviour comes from .eval() below.
    BCP_Net = net_factory(net_type=args.model, in_chns=1, class_num=args.num_classes, mode="train")

    # torch.load unpickles the file: only load checkpoints from a trusted source.
    BCP_Net.load_state_dict(torch.load(BCP_model_path))
    print("Init models' weight successfully")
    BCP_Net.eval()

    def worker_init_fn(worker_id):
        # Give every DataLoader worker its own `random` seed.
        random.seed(1337 + worker_id)

    db_train = BaseDataSets(base_dir=args.root_dir,
                            split='train',
                            num=None,
                            transform=transforms.Compose([RandomGenerator(args.patch_size)]))
    total_slices = len(db_train)
    labeled_slice = patients_to_slices(args.root_dir, args.label_num)
    print("Total slices is: {}, labeled slices is:{}".format(total_slices, labeled_slice))
    labeled_idx = list(range(0, labeled_slice))
    unlabeled_idxs = list(range(labeled_slice, total_slices))
    batch_sampler = TwoStreamBatchSampler(labeled_idx, unlabeled_idxs, args.batch_size, args.batch_size-args.labeled_bs)
    trainloader = DataLoader(db_train, batch_sampler=batch_sampler, num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn)
    picture_number = 0

    # Pure inference: no autograd graph is needed for the forward passes.
    with torch.no_grad():
        for epoch in range(3):
            for sampled_batch in trainloader:
                volume_batch = sampled_batch['image'].cuda()
                # Labels are only used on the CPU, so they never go to the GPU.
                label_np = sampled_batch['label'].detach().cpu().numpy()

                pred, BCP_feature = BCP_Net(volume_batch)
                B_pred = get_ACDC_masks(pred)

                # (B, C, H, W) -> one row of C features per pixel.
                f_dim, x_, y_ = BCP_feature.shape[1:]
                BCP_feature = BCP_feature.permute(0, 2, 3, 1).contiguous().view(-1, f_dim)

                batch_len = label_np.shape[0]
                resized_label = np.zeros((batch_len, x_, y_))
                for i in range(batch_len):
                    # cv2 expects dsize as (width, height). Nearest-neighbour keeps
                    # the values valid class indices; the default bilinear mode
                    # would blend neighbouring classes into other values.
                    resized_label[i] = cv2.resize(label_np[i].squeeze(), (y_, x_),
                                                  interpolation=cv2.INTER_NEAREST)
                label = resized_label.reshape(-1, 1)

                # NOTE(review): assumes `pred` has the same spatial size as the
                # feature map, so rows of BCP_pred line up with rows of
                # BCP_feature — confirm against the UNet definition.
                BCP_pred = B_pred.view(-1, 1).detach().cpu().numpy()
                BCP_feature = BCP_feature.detach().cpu().numpy()

                plot_kde(BCP_feature, BCP_pred, label, specific_c, f_dim, picture_number)
                picture_number += 1