In [1]:
import warnings
# Silence every Python warning category for the rest of the kernel session.
# NOTE(review): this also hides deprecation/runtime warnings raised by the
# training code below — consider narrowing the filter once the run is stable.
warnings.filterwarnings('ignore')

In [2]:
import numpy as np 
import matplotlib.pyplot as plt 
import os 
import logging
import random
from tqdm import tqdm
from skimage.measure import label 

import torch 
import torch.nn as nn 
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision.datasets import MNIST


# module 
from networks.net_factory import BCP_net
from dataset.basedataset import ACDCDataset
from dataset.utils import RandomGenerator, ACDC_patients_to_slices, TwoStreamBatchSampler
from utils.params import params 
from utils.masks import generate_mask, random_mask, contact_mask
from utils.losses import mix_loss
from utils.valid2d import test_single_volume
from networks.utils import save_net_opt, load_net_opt, load_net, get_current_consistency_weight, update_model_ema
from MAE.maeLoss import reconstruction_loss
from MAE.maskImage import mask_image

# writer 
from torch.utils.tensorboard import SummaryWriter

2025-04-26 14:52:52.714495: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-04-26 14:52:52.728520: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-04-26 14:52:52.732582: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-26 14:52:52.743547: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
# Build the experiment configuration object (from utils.params). The cells
# below read fields such as base_lr, batch_size, labeled_bs, label_num,
# patch_size, seed and gpu from it.
args = params() 

In [4]:
def get_ACDC_2DLargestCC(segmentation, num_classes=4):
    """Keep only the largest connected component of every foreground class.

    For each sample along dim 0 and each class id ``c`` in
    ``1 .. num_classes - 1``, the pixels equal to ``c`` are split into
    connected components with ``skimage.measure.label`` and only the
    component containing the most pixels is kept (written back with value
    ``c``). Classes that do not occur in a sample contribute nothing. The
    per-class results are summed into a single label map per sample.

    Args:
        segmentation: tensor of integer class ids, indexed by sample along
            dim 0 (assumed to be a hard label map such as the output of
            ``get_ACDC_masks`` — TODO confirm expected shape with callers).
        num_classes: total number of classes including background. Defaults
            to 4, which reproduces the previously hard-coded classes 1..3.

    Returns:
        A float32 tensor with the same shape as ``segmentation``, placed on
        the same device as ``segmentation`` (previously it was always moved
        to CUDA, which is identical for CUDA inputs).
    """
    # An empty batch has nothing to clean; np.stack would reject an empty list.
    if segmentation.shape[0] == 0:
        return torch.zeros_like(segmentation, dtype=torch.float32)

    cleaned = []
    for sample in segmentation:
        sample_np = sample.detach().cpu().numpy()
        merged = np.zeros(sample_np.shape, dtype=np.int64)
        for c in range(1, num_classes):
            # Binary map of class c, labelled into connected components
            # (component ids start at 1, 0 is background).
            components = label((sample_np == c).astype(np.int64))
            if components.max() != 0:
                # Skip bin 0 (background) so argmax picks the largest real
                # component; +1 maps the bin index back to a component id.
                sizes = np.bincount(components.flat)[1:]
                merged += (components == np.argmax(sizes) + 1) * c
        cleaned.append(merged)

    # Stack once into a single ndarray: converting a Python list of ndarrays
    # element by element is very slow. float32 matches what torch.Tensor(...)
    # used to return.
    return torch.as_tensor(np.stack(cleaned), dtype=torch.float32,
                           device=segmentation.device)
    
def get_ACDC_masks(output, nms=0):
    """Convert raw network output into a hard label map.

    A softmax is taken over dim 1 and the index of the largest probability
    along that dim is returned. When ``nms == 1`` the label map is further
    passed through ``get_ACDC_2DLargestCC`` before being returned.

    Args:
        output: network output with the class scores along dim 1.
        nms: set to 1 to apply the largest-connected-component filter.

    Returns:
        Tensor of class indices with dim 1 removed (or the filtered version
        of it when ``nms == 1``).
    """
    class_probs = F.softmax(output, dim=1)
    hard_labels = class_probs.max(dim=1).indices
    if nms != 1:
        return hard_labels
    return get_ACDC_2DLargestCC(hard_labels)

#### 1. BCP network

In [5]:
def pretrain(args, snapshot_path): 
    """Pre-train the segmentation network on copy-paste mixes of labeled slices.

    Each iteration takes the labeled part of the batch, splits it in two
    halves (a / b), mixes the two halves with the mask returned by
    ``generate_mask`` and optimises the mean of the dice and cross-entropy
    terms returned by ``mix_loss``. Every 200 iterations the model is
    evaluated on the validation split and, when the mean dice improves, the
    model and optimizer are saved with ``save_net_opt`` both under an
    iteration-specific name and as ``<args.model>_best_model.pth``.

    Args:
        args: configuration object; reads base_lr, num_classes,
            pretrain_iterations, gpu, labeled_bs, batch_size, seed, root_dir,
            patch_size, label_num and model.
        snapshot_path: directory that receives checkpoints and the
            TensorBoard log (``<snapshot_path>/log``).

    Returns:
        None. Results are written to ``snapshot_path`` and to the logger.
    """
    base_lr = args.base_lr
    num_classes = args.num_classes
    max_iterations = args.pretrain_iterations
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    labeled_bs = args.labeled_bs

    # Load data
    def worker_init_fn(worker_id):
        # Give every DataLoader worker its own, reproducible `random` stream.
        random.seed(args.seed + worker_id)

    label_sub_bs = int(args.labeled_bs / 2)
    db_train = ACDCDataset(base_dir=args.root_dir, split='train',
                           transform=transforms.Compose([RandomGenerator(args.patch_size)]))
    db_val = ACDCDataset(base_dir=args.root_dir, split='val')
    total_slices = len(db_train)
    labeled_slices = ACDC_patients_to_slices(args.root_dir, args.label_num)
    print(f'Total slices is: {total_slices}, labeled slices is: {labeled_slices}')
    # The first `labeled_slices` indices are treated as labeled, the rest as unlabeled.
    labeled_idxes = list(range(0, labeled_slices))
    unlabeled_idxes = list(range(labeled_slices, total_slices))
    batch_sampler = TwoStreamBatchSampler(labeled_idxes, unlabeled_idxes,
                                          args.batch_size, args.batch_size - args.labeled_bs)

    trainloader = DataLoader(db_train, batch_sampler=batch_sampler, num_workers=4,
                             pin_memory=True, worker_init_fn=worker_init_fn)
    valloader = DataLoader(db_val, batch_size=1, shuffle=False, num_workers=1)

    # Model
    model = BCP_net(in_chns=1, num_classes=num_classes)
    optimizer = optim.SGD(model.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001)

    writer = SummaryWriter(snapshot_path + '/log')
    logging.info('Start pre_training')
    logging.info(f'{len(trainloader)} iterations per epoch')

    model.train()
    iter_num = 0
    # +1 so that the last, partial epoch still reaches max_iterations.
    max_epoch = max_iterations // len(trainloader) + 1
    best_performance = 0.0
    iterator = tqdm(range(max_epoch), ncols=70)
    # try/finally: the writer and progress bar must be released even when the
    # run is interrupted (e.g. KeyboardInterrupt from the notebook).
    try:
        for _ in iterator:
            for sampled_batch in trainloader:
                volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
                volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()

                # Create BCP image: two halves of the labeled part of the batch.
                img_a, img_b = volume_batch[:label_sub_bs], volume_batch[label_sub_bs:labeled_bs]
                lab_a, lab_b = label_batch[:label_sub_bs], label_batch[label_sub_bs:labeled_bs]
                img_mask, loss_mask = generate_mask(img_a)
                # Mixed ground truth, used only for the TensorBoard preview below.
                gt_mixl = lab_a * img_mask + lab_b * (1 - img_mask)

                # Feed to model
                net_input = img_a * img_mask + img_b * (1 - img_mask)
                out_mixl = model(net_input, mode='seg')
                loss_dice, loss_ce = mix_loss(out_mixl, lab_a, lab_b, loss_mask, u_weight=1.0, unlab=True)

                loss = (loss_dice + loss_ce) / 2
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                iter_num += 1

                writer.add_scalar('info/total loss', loss, iter_num)
                writer.add_scalar('info/mix_dice', loss_dice, iter_num)
                writer.add_scalar('info/mix_ce', loss_ce, iter_num)

                logging.info(f'iteration: {iter_num}, loss: {loss}, mix_dice: {loss_dice}, mix_ce: {loss_ce}')
                if iter_num % 20 == 0:
                    # Preview sample index 1 of the mixed batch
                    # (requires at least two mixed samples per batch).
                    image = net_input[1, 0:1, :, :]  # shape = [1, H, W]
                    image = (image - image.min()) / (image.max() - image.min() + 1e-5)  # normalize to [0,1]
                    writer.add_image('pre_train/Mixed_Image', image, iter_num)
                    outputs = torch.argmax(torch.softmax(out_mixl, dim=1), dim=1, keepdim=True)
                    # * 50 spreads the small class ids over a visible intensity range.
                    writer.add_image('pre_train/Mixed_Prediction', outputs[1, ...] * 50, iter_num)
                    labs = gt_mixl[1, ...].unsqueeze(0) * 50
                    writer.add_image('pre_train/Mixed_GroundTruth', labs, iter_num)

                if iter_num > 0 and iter_num % 200 == 0:
                    model.eval()
                    metric_list = 0.0
                    # Gradients are never needed for validation.
                    with torch.no_grad():
                        for val_batch in valloader:
                            metric_i = test_single_volume(val_batch["image"], val_batch["label"], model, classes=num_classes)
                            metric_list += np.array(metric_i)
                    metric_list = metric_list / len(db_val)
                    # Row i of metric_list holds (dice, hd95) of foreground class i + 1.
                    for class_i in range(num_classes - 1):
                        writer.add_scalar('info/val_{}_dice'.format(class_i + 1), metric_list[class_i, 0], iter_num)
                        writer.add_scalar('info/val_{}_hd95'.format(class_i + 1), metric_list[class_i, 1], iter_num)

                    performance = np.mean(metric_list, axis=0)[0]
                    writer.add_scalar('info/val_mean_dice', performance, iter_num)

                    if performance > best_performance:
                        best_performance = performance
                        save_mode_path = os.path.join(snapshot_path, 'iter_{}_dice_{}.pth'.format(iter_num, round(best_performance, 4)))
                        save_best_path = os.path.join(snapshot_path, '{}_best_model.pth'.format(args.model))
                        save_net_opt(model, optimizer, save_mode_path)
                        save_net_opt(model, optimizer, save_best_path)

                    logging.info('iteration %d : mean_dice : %f' % (iter_num, performance))
                    model.train()

                if iter_num >= max_iterations:
                    break
            if iter_num >= max_iterations:
                break
    finally:
        iterator.close()
        writer.close()


In [None]:
def selftrain(args, pretrain_snapshot_path, snapshot_path): 
    """Self-training stage: copy-paste mixing of labeled and pseudo-labeled slices.

    Two networks are created: ``model`` (optimised with SGD) and
    ``ema_model`` (built with ``ema=True`` and updated after every step via
    ``update_model_ema(model, ema_model, 0.99)``). Both are initialised from
    ``<pretrain_snapshot_path>/<args.model>_best_model.pth``. Each iteration:

    * ``ema_model`` predicts the unlabeled halves; the predictions are turned
      into pseudo labels with ``get_ACDC_masks(..., nms=1)``;
    * labeled and unlabeled images are mixed in both directions with the mask
      from ``generate_mask`` and scored with ``mix_loss``;
    * a masked copy of the whole batch is passed through ``model`` in
      ``'recon'`` mode and a reconstruction loss is computed (see the review
      note at that point — it is currently not part of the optimised loss).

    Every 200 iterations the model is validated; when mean dice improves its
    ``state_dict`` is saved under an iteration-specific name and as
    ``<args.model>_best_model.pth`` inside ``snapshot_path``.

    Args:
        args: configuration object; reads base_lr, num_classes,
            selftrain_iterations, gpu, model, labeled_bs, batch_size, seed,
            root_dir, patch_size, label_num, u_weight and bcp_weight.
        pretrain_snapshot_path: directory containing the pre-training checkpoint.
        snapshot_path: directory that receives checkpoints and the
            TensorBoard log (``<snapshot_path>/log``).

    Returns:
        None. Results are written to ``snapshot_path`` and to the logger.
    """
    # Extract params
    base_lr = args.base_lr
    num_classes = args.num_classes
    max_iterations = args.selftrain_iterations
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    pretrained_model = os.path.join(pretrain_snapshot_path, f'{args.model}_best_model.pth')
    labeled_bs = args.labeled_bs
    labeled_sub_bs = int(args.labeled_bs / 2)
    unlabeled_sub_bs = int((args.batch_size - args.labeled_bs) / 2)

    # Load data
    def worker_init_fn(worker_id):
        # Give every DataLoader worker its own, reproducible `random` stream.
        random.seed(args.seed + worker_id)

    db_train = ACDCDataset(base_dir=args.root_dir, split='train',
                           transform=transforms.Compose([RandomGenerator(args.patch_size)]))
    db_val = ACDCDataset(base_dir=args.root_dir, split='val')
    total_slices = len(db_train)
    labeled_slices = ACDC_patients_to_slices(args.root_dir, args.label_num)
    print(f'Total slices: {total_slices}, Labeled slices: {labeled_slices}')
    # The first `labeled_slices` indices are treated as labeled, the rest as unlabeled.
    labeled_idxs = list(range(0, labeled_slices))
    unlabeled_idxs = list(range(labeled_slices, total_slices))
    batch_sampler = TwoStreamBatchSampler(labeled_idxs, unlabeled_idxs,
                                          args.batch_size, args.batch_size - args.labeled_bs)
    # worker_init_fn was defined but never handed to the loader; pass it so
    # worker seeding matches the pre-training stage.
    trainloader = DataLoader(db_train, batch_sampler=batch_sampler, num_workers=4,
                             pin_memory=True, worker_init_fn=worker_init_fn)
    valloader = DataLoader(db_val, batch_size=1, shuffle=False, num_workers=1)

    # Model — use the configured class count so the architecture matches the
    # checkpoint produced by pretrain (which is built from args.num_classes).
    model = BCP_net(in_chns=1, num_classes=num_classes)
    ema_model = BCP_net(in_chns=1, num_classes=num_classes, ema=True)
    optimizer = optim.SGD(model.parameters(), lr=base_lr, momentum=0.9, weight_decay=1e-4)
    load_net(ema_model, pretrained_model)
    load_net_opt(model, optimizer, pretrained_model)
    logging.info(f'Loaded from {pretrained_model}')
    writer = SummaryWriter(snapshot_path + '/log')
    logging.info('Start self-training')
    logging.info(f'{len(trainloader)} iterations per epoch')

    model.train()
    ema_model.train()
    iter_num = 0
    # +1 so that the last, partial epoch still reaches max_iterations.
    max_epoch = max_iterations // len(trainloader) + 1
    best_performance = 0.0
    iterator = tqdm(range(0, max_epoch), ncols=70)
    # try/finally: the writer and progress bar must be released even when the
    # run is interrupted (e.g. KeyboardInterrupt from the notebook).
    try:
        for _ in iterator:
            for sampled_batch in trainloader:
                volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
                volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()

                # BCP augmentation: batch layout is
                # [labeled a | labeled b | unlabeled a | unlabeled b].
                img_a, img_b = volume_batch[:labeled_sub_bs], volume_batch[labeled_sub_bs:labeled_bs]
                lab_a, lab_b = label_batch[:labeled_sub_bs], label_batch[labeled_sub_bs:labeled_bs]
                uimg_a = volume_batch[labeled_bs:labeled_bs + unlabeled_sub_bs]
                uimg_b = volume_batch[labeled_bs + unlabeled_sub_bs:]
                with torch.no_grad():
                    pre_a = ema_model(uimg_a, mode='seg')
                    pre_b = ema_model(uimg_b, mode='seg')
                    plab_a = get_ACDC_masks(pre_a, nms=1)
                    plab_b = get_ACDC_masks(pre_b, nms=1)
                    img_mask, loss_mask = generate_mask(img_a)
                    # Mixed targets, used only for the TensorBoard previews below.
                    unl_label = plab_a * img_mask + lab_a * (1 - img_mask) # TODO: Problem !! 
                    l_label = lab_b * img_mask + plab_b * (1 - img_mask) #TODO: Problem !!! 
                # Only logged below; it does not scale any loss term in this function.
                consistency_weight = get_current_consistency_weight(args, iter_num // 150) # ADJUST
                # ---------------------- Segmentation ----------------------------------- #
                net_input_unl = uimg_a * img_mask + img_a * (1 - img_mask)
                net_input_l = img_b * img_mask + uimg_b * (1 - img_mask)
                out_unl = model(net_input_unl, mode='seg')
                out_l = model(net_input_l, mode='seg')
                unl_dice, unl_ce = mix_loss(out_unl, plab_a, lab_a, loss_mask, u_weight=args.u_weight, unlab=True)
                l_dice, l_ce = mix_loss(out_l, lab_b, plab_b, loss_mask, u_weight=args.u_weight)

                loss_ce = unl_ce + l_ce
                loss_dice = unl_dice + l_dice
                loss_bcp = (loss_ce + loss_dice) / 2
                # ----------------------------------- Reconstruction task --------------------------------- #
                masked_img, mask = mask_image(volume_batch, block_size=5, mask_ratio=0.5)
                masked_img = masked_img.cuda()
                mask = mask.cuda()

                out_recon = model(masked_img, mode='recon')
                # NOTE(review): rec_loss is computed but NOT added to `loss`
                # below, so the reconstruction branch receives no gradient and
                # args.recon_weight is unused here. Left unchanged on purpose —
                # confirm whether `+ args.recon_weight * rec_loss` is intended.
                rec_loss = reconstruction_loss(out_recon, volume_batch, mask)

                # --------------------------------------- Backward -------------------------------------  #
                loss = args.bcp_weight * loss_bcp
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                iter_num += 1
                update_model_ema(model, ema_model, 0.99)

                writer.add_scalar('info/total_loss', loss, iter_num)
                writer.add_scalar('info/mix_dice', loss_dice, iter_num)
                writer.add_scalar('info/mix_ce', loss_ce, iter_num)
                writer.add_scalar('info/consistency_weight', consistency_weight, iter_num)
                # writer.add_scalar('info/recon_loss', rec_loss.item(), iter_num)
                # logging.info(f'iteration: {iter_num}, mix_dice: {loss_dice}, mix_ce: {loss_ce}, rec_loss: {rec_loss}')

                if iter_num % 20 == 0:
                    # Compare the reconstruct task
                    gt_image = volume_batch[1, 0:1].detach().cpu()
                    # +1e-5 guards against 0/0 on a constant slice
                    # (same epsilon as the pre-training preview).
                    gt_image = (gt_image - gt_image.min()) / (gt_image.max() - gt_image.min() + 1e-5)
                    writer.add_image('train/Reconstruction_GT', gt_image, iter_num)

                    # recon_image = out_recon[1, 0:1].detach().cpu()
                    # recon_image = (recon_image - recon_image.min()) / (recon_image.max() - recon_image.min() + 1e-5)
                    # writer.add_image('train/Reconstruction', recon_image, iter_num)

                    # Previews use sample index 1 of each mixed half-batch;
                    # * 50 spreads the small class ids over a visible range.
                    image = net_input_unl[1, 0:1, :, :]
                    writer.add_image('train/Un_Image', image, iter_num)
                    outputs = torch.argmax(torch.softmax(out_unl, dim=1), dim=1, keepdim=True)
                    writer.add_image('train/Un_Prediction', outputs[1, ...] * 50, iter_num)
                    labs = unl_label[1, ...].unsqueeze(0) * 50
                    writer.add_image('train/Un_GroundTruth', labs, iter_num)

                    image_l = net_input_l[1, 0:1, :, :]
                    writer.add_image('train/L_Image', image_l, iter_num)
                    outputs_l = torch.argmax(torch.softmax(out_l, dim=1), dim=1, keepdim=True)
                    writer.add_image('train/L_Prediction', outputs_l[1, ...] * 50, iter_num)
                    labs_l = l_label[1, ...].unsqueeze(0) * 50
                    writer.add_image('train/L_GroundTruth', labs_l, iter_num)

                if iter_num > 0 and iter_num % 200 == 0:
                    model.eval()
                    metric_list = 0.0
                    # Gradients are never needed for validation.
                    with torch.no_grad():
                        for val_batch in valloader:
                            metric_i = test_single_volume(val_batch["image"], val_batch["label"], model, classes=num_classes)
                            metric_list += np.array(metric_i)
                    metric_list = metric_list / len(db_val)
                    # Row i of metric_list holds (dice, hd95) of foreground class i + 1.
                    for class_i in range(num_classes - 1):
                        writer.add_scalar('info/val_{}_dice'.format(class_i + 1), metric_list[class_i, 0], iter_num)
                        writer.add_scalar('info/val_{}_hd95'.format(class_i + 1), metric_list[class_i, 1], iter_num)

                    performance = np.mean(metric_list, axis=0)[0]
                    writer.add_scalar('info/val_mean_dice', performance, iter_num)

                    if performance > best_performance:
                        best_performance = performance
                        save_mode_path = os.path.join(snapshot_path, 'iter_{}_dice_{}.pth'.format(iter_num, round(best_performance, 4)))
                        save_best_path = os.path.join(snapshot_path, '{}_best_model.pth'.format(args.model))
                        # Only the weights are stored here (no optimizer state).
                        torch.save(model.state_dict(), save_mode_path)
                        torch.save(model.state_dict(), save_best_path)

                    logging.info('iteration %d : mean_dice : %f' % (iter_num, performance))
                    model.train()

                if iter_num >= max_iterations:
                    break
            if iter_num >= max_iterations:
                break
    finally:
        iterator.close()
        writer.close()


In [7]:
# Output directories for the two training stages.
pretrain_snapshot_path = 'modelBCP/pretrain'
selftrain_snapshot_path = 'modelBCP/selftrain'
for snapshot in [pretrain_snapshot_path, selftrain_snapshot_path]:
    os.makedirs(snapshot, exist_ok=True)

# Seed the global RNGs so that a Restart & Run All gives comparable runs;
# the DataLoader workers are seeded separately from args.seed inside the
# training functions.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# Drop handlers left over from a previous execution of this cell so that
# basicConfig takes effect again; close them so the old log file handle is
# released instead of leaking on every re-run.
for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)
    handler.close()

logging.basicConfig(
    filename='modelBCP/log.txt',
    level=logging.INFO,
    format='[%(asctime)s] %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
# Log out configuration
logging.info("========== Experiment Configuration ==========")
logging.info(f"Pretrain Iterations  : {args.pretrain_iterations}")
logging.info(f"Selftrain Iterations : {args.selftrain_iterations}")
logging.info(f"Batch Size           : {args.batch_size}")
logging.info(f"Labeled Batch Size   : {args.labeled_bs}")
logging.info(f"Labelled Patients    : {args.label_num}")
logging.info(f"Learning Rate        : {args.base_lr}")
logging.info(f"Patch Size           : {args.patch_size}")
logging.info(f"BCP Weight           : {args.bcp_weight}")
logging.info(f"Reconstruction Weight: {args.recon_weight}")
logging.info("==============================================")

# Stage 1 writes <model>_best_model.pth into pretrain_snapshot_path;
# stage 2 loads that checkpoint from the same directory.
pretrain(args, pretrain_snapshot_path)
selftrain(args, pretrain_snapshot_path, selftrain_snapshot_path)

Mode: train: 1312 samples in total
Mode: val: 20 samples in total
Total slices is: 1312, labeled slices is: 136


 99%|████████████████████████████████▋| 90/91 [01:30<00:01,  1.01s/it]


Mode: train: 1312 samples in total
Mode: val: 20 samples in total
Total slices: 1312, Labeled slices: 136


 15%|████▉                            | 11/73 [00:30<02:50,  2.75s/it]


KeyboardInterrupt: 