In [1]:
import warnings
# Silence every warning for the rest of the kernel session so the training
# output below is not buried under library warnings.
warnings.simplefilter('ignore')

In [2]:
import numpy as np 
import matplotlib.pyplot as plt 
import os 
import logging
import random
from tqdm import tqdm

import torch 
import torch.nn as nn 
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision.datasets import MNIST


# module 
from networks.net_factory import BCP_net
from dataset.basedataset import ACDCDataset
from dataset.utils import RandomGenerator, ACDC_patients_to_slices, TwoStreamBatchSampler
from utils.params import params 
from utils.masks import generate_mask, random_mask, contact_mask
from utils.losses import mix_loss
from utils.valid2d import test_single_volume
from networks.utils import save_net_opt, load_net_opt, load_net


# writer 
from torch.utils.tensorboard import SummaryWriter

2025-04-22 22:19:24.089056: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-04-22 22:19:24.103038: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-04-22 22:19:24.106813: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-22 22:19:24.118160: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
# Build the run configuration object. `pretrain` below reads its
# hyper-parameters (base_lr, batch sizes, iteration count, seed, ...) and
# dataset settings (root_dir, patch_size, label_num) from this object.
args = params() 

#### 1. BCP network

In [4]:
# pre-train BCP
def _log_mixed_sample(writer, net_input, out_mixl, gt_mixl, iter_num):
    """Write one mixed input, its prediction and its mixed label to TensorBoard.

    Args:
        writer: TensorBoard ``SummaryWriter`` receiving the images.
        net_input: mixed image batch fed to the model (indexed as
            ``[sample, channel, H, W]``).
        out_mixl: model output for ``net_input``; class scores on dim 1.
        gt_mixl: mixed label batch matching ``net_input``.
        iter_num: global step used as the TensorBoard step.
    """
    # Sample 1 is logged when available; with a single-image mixed batch fall
    # back to sample 0 instead of raising IndexError.
    idx = min(1, net_input.shape[0] - 1)

    image = net_input[idx, 0:1, :, :]  # shape = [1, H, W]
    image = (image - image.min()) / (image.max() - image.min() + 1e-5)  # normalize to [0,1]
    writer.add_image('pre_train/Mixed_Image', image, iter_num)

    # Class index per pixel; the factor 50 presumably spreads the class ids
    # over the grey-level range so they are distinguishable — TODO confirm.
    outputs = torch.argmax(torch.softmax(out_mixl, dim=1), dim=1, keepdim=True)
    writer.add_image('pre_train/Mixed_Prediction', outputs[idx, ...] * 50, iter_num)
    labs = gt_mixl[idx, ...].unsqueeze(0) * 50
    writer.add_image('pre_train/Mixed_GroundTruth', labs, iter_num)


def _validate(model, valloader, num_volumes, num_classes):
    """Return the validation metrics summed over ``valloader`` and divided by ``num_volumes``.

    The caller is responsible for switching the model between eval and train
    mode. The result is indexed as ``[class, metric]`` by the caller, with
    metric 0 logged as Dice and metric 1 as HD95.
    """
    metric_sum = 0.0
    # No gradients are needed for evaluation; disabling autograd avoids
    # building graphs regardless of what test_single_volume does internally.
    with torch.no_grad():
        for val_batch in valloader:
            metric_i = test_single_volume(val_batch["image"], val_batch["label"], model, classes=num_classes)
            metric_sum += np.array(metric_i)
    return metric_sum / num_volumes


def pretrain(args, snapshot_path):
    """Pre-train the BCP network on mask-mixed pairs of labeled slices.

    Each step splits the labeled part of the batch into two halves, blends
    them with the mask returned by ``generate_mask`` and optimises the mean of
    the Dice and cross-entropy terms returned by ``mix_loss``. Every 20
    iterations one sample is written to TensorBoard; every 200 iterations the
    model is scored on the validation split and a checkpoint is saved whenever
    the mean Dice improves.

    Args:
        args: configuration object. Reads ``base_lr``, ``num_classes``,
            ``pretrain_iterations``, ``gpu``, ``labeled_bs``, ``batch_size``,
            ``seed``, ``root_dir``, ``patch_size``, ``label_num`` and ``model``.
        snapshot_path: directory that receives the TensorBoard log
            (``<snapshot_path>/log``) and the checkpoints.

    Returns:
        None. Progress is reported through ``logging``, TensorBoard and the
        checkpoint files.
    """
    base_lr = args.base_lr
    num_classes = args.num_classes
    max_iterations = args.pretrain_iterations
    # NOTE(review): this only influences device selection if CUDA has not been
    # initialised yet in this kernel — confirm.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    labeled_bs = args.labeled_bs
    # The labeled part of each batch is split into two halves that get mixed.
    label_sub_bs = int(labeled_bs / 2)

    # Load data
    def worker_init_fn(worker_id):
        # Give every DataLoader worker its own `random` seed.
        random.seed(args.seed + worker_id)

    db_train = ACDCDataset(base_dir=args.root_dir, split='train',
                           transform=transforms.Compose([RandomGenerator(args.patch_size)]))
    db_val = ACDCDataset(base_dir=args.root_dir, split='val')
    total_slices = len(db_train)
    labeled_slices = ACDC_patients_to_slices(args.root_dir, args.label_num)
    print(f'Total slices is: {total_slices}, labeled slices is: {labeled_slices}')
    # The first `labeled_slices` indices are treated as labeled, the rest as unlabeled.
    labeled_idxes = list(range(0, labeled_slices))
    unlabeled_idxes = list(range(labeled_slices, total_slices))
    batch_sampler = TwoStreamBatchSampler(labeled_idxes, unlabeled_idxes, args.batch_size, args.batch_size - args.labeled_bs)

    trainloader = DataLoader(db_train, batch_sampler=batch_sampler, num_workers=4, pin_memory=True, worker_init_fn=worker_init_fn)
    valloader = DataLoader(db_val, batch_size=1, shuffle=False, num_workers=1)

    # Model
    model = BCP_net(in_chns=1, num_classes=num_classes)
    optimizer = optim.SGD(model.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001)

    logging.info('Start pre_training')
    logging.info(f'{len(trainloader)} iterations per epoch')

    model.train()
    iter_num = 0
    # Enough epochs to reach max_iterations; the loops below stop early once it is hit.
    max_epoch = max_iterations // len(trainloader) + 1
    best_performance = 0.0

    writer = SummaryWriter(os.path.join(snapshot_path, 'log'))
    iterator = tqdm(range(max_epoch), ncols=70)
    try:
        for _ in iterator:
            for sampled_batch in trainloader:
                volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
                volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()

                # Create BCP image: only the labeled part of the batch is used here.
                img_a, img_b = volume_batch[:label_sub_bs], volume_batch[label_sub_bs:labeled_bs]
                lab_a, lab_b = label_batch[:label_sub_bs], label_batch[label_sub_bs:labeled_bs]
                img_mask, loss_mask = generate_mask(img_a)
                gt_mixl = lab_a * img_mask + lab_b * (1 - img_mask)

                # Feed to model
                net_input = img_a * img_mask + img_b * (1 - img_mask)
                out_mixl = model(net_input)
                loss_dice, loss_ce = mix_loss(out_mixl, lab_a, lab_b, loss_mask, u_weight=1.0, unlab=True)

                loss = (loss_dice + loss_ce) / 2
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                iter_num += 1

                # Convert once to plain floats so the log file shows numbers
                # rather than tensor reprs with device/grad_fn noise.
                loss_val, dice_val, ce_val = float(loss), float(loss_dice), float(loss_ce)
                writer.add_scalar('info/total loss', loss_val, iter_num)
                writer.add_scalar('info/mix_dice', dice_val, iter_num)
                writer.add_scalar('info/mix_ce', ce_val, iter_num)

                logging.info(f'iteration: {iter_num}, loss: {loss_val}, mix_dice: {dice_val}, mix_ce: {ce_val}')
                if iter_num % 20 == 0:
                    _log_mixed_sample(writer, net_input, out_mixl, gt_mixl, iter_num)

                if iter_num > 0 and iter_num % 200 == 0:
                    model.eval()
                    metric_list = _validate(model, valloader, len(db_val), num_classes)
                    for class_i in range(num_classes - 1):
                        writer.add_scalar('info/val_{}_dice'.format(class_i+1), metric_list[class_i, 0], iter_num)
                        writer.add_scalar('info/val_{}_hd95'.format(class_i+1), metric_list[class_i, 1], iter_num)

                    # Column 0 holds the Dice scores; average them over classes.
                    performance = np.mean(metric_list, axis=0)[0]
                    writer.add_scalar('info/val_mean_dice', performance, iter_num)

                    if performance > best_performance:
                        best_performance = performance
                        # Keep both a per-iteration snapshot and a fixed-name "best" copy.
                        save_mode_path = os.path.join(snapshot_path, 'iter_{}_dice_{}.pth'.format(iter_num, round(best_performance, 4)))
                        save_best_path = os.path.join(snapshot_path, '{}_best_model.pth'.format(args.model))
                        save_net_opt(model, optimizer, save_mode_path)
                        save_net_opt(model, optimizer, save_best_path)

                    logging.info('iteration %d : mean_dice : %f' % (iter_num, performance))
                    model.train()

                if iter_num >= max_iterations:
                    break
            if iter_num >= max_iterations:
                break
    finally:
        # Always release the progress bar and flush/close the event file, even
        # when training is interrupted or raises.
        iterator.close()
        writer.close()


In [None]:
def selftrain(args, pretrain_snapshot_path, selftrain_snapshot_path):
    """Placeholder for the self-training stage; not implemented yet.

    All three arguments are accepted and ignored, and the call returns
    ``None`` without doing any work.
    """
    return None

In [5]:
# Checkpoint/log directories for the two training stages.
pretrain_snapshot_path = 'modelBCP/pretrain'
selftrain_snapshot_path = 'modelBCP/selftrain'
for snapshot in [pretrain_snapshot_path, selftrain_snapshot_path]:
    os.makedirs(snapshot, exist_ok=True)

# basicConfig does nothing when the root logger already has handlers, so drop
# any left over from an earlier run of this cell. Close each one as well;
# otherwise every re-run leaves another open handle on the log file.
for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)
    handler.close()

logging.basicConfig(
    filename='modelBCP/log.txt',
    level=logging.INFO,
    format='[%(asctime)s] %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)

pretrain(args, pretrain_snapshot_path)

Mode: train: 1312 samples in total
Mode: val: 20 samples in total
Total slices is: 1312, labeled slices is: 136


 99%|██████████████████████████████▊| 181/182 [03:22<00:01,  1.12s/it]
