In [1]:
import ImageTool.tool as tool

2025-03-05 10:27:12.623190: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-03-05 10:27:12.632421: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:479] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-03-05 10:27:12.644263: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:10575] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-03-05 10:27:12.644293: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1442] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-03-05 10:27:12.654201: I tensorflow/core/platform/cpu_feature_gua

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from efficientunet2 import get_efficientunet_b2
from models.Basic_module import Criterion
# from Basic_module import Criterion, Visualization
from ResNet3D import ResNet_appearance, ResNet_shape


class BayeSeg(nn.Module):
    """3-D BayeSeg network.

    Three sub-networks each emit a mean and a log-variance (the two halves of
    their channel dimension):

    * ``res_shape``  -> shape component ``x``
    * ``res_appear`` -> appearance component ``m``
    * ``unet``       -> segmentation logits ``z`` (fed with ``x`` repeated to
      three channels)

    ``forward`` samples the three variables and returns the prediction together
    with the per-voxel terms that ``BayeSeg_Criterion.loss_Bayes`` sums up.

    Args:
        args: object exposing ``num_classes`` and the hyper-parameters read in
            ``forward`` (``gamma_rho``, ``phi_rho``, ``gamma_upsilon``,
            ``phi_upsilon``, ``gamma_omega``, ``phi_omega``, ``alpha_pi``,
            ``beta_pi``, ``sigma_0``).
    """

    def __init__(self, args):
        super(BayeSeg, self).__init__()

        self.args = args
        self.num_classes = args.num_classes

        self.res_shape = ResNet_shape(num_out_ch=2)
        self.res_appear = ResNet_appearance(num_out_ch=2, num_block=6, bn=True)
        self.unet = get_efficientunet_b2(
            out_channels=2 * args.num_classes, pretrained=False
        )

        self.softmax = nn.Softmax(dim=1)

        # Fixed (non-trainable) 3x3x3 kernel: centre weight 1, each of the six
        # face neighbours -1/6 (a simple 3-D Laplacian).
        Dx = torch.zeros([1, 1, 3, 3, 3], dtype=torch.float)
        Dx[0, 0, 1, 1, 1] = 1
        Dx[0, 0, 1, 1, 0] = -1 / 6  # left
        Dx[0, 0, 1, 1, 2] = -1 / 6  # right
        Dx[0, 0, 1, 0, 1] = -1 / 6  # up
        Dx[0, 0, 1, 2, 1] = -1 / 6  # down
        Dx[0, 0, 0, 1, 1] = -1 / 6  # front
        Dx[0, 0, 2, 1, 1] = -1 / 6  # back
        self.Dx = nn.Parameter(data=Dx, requires_grad=False)

    @staticmethod
    def sample_normal_jit(mu, log_var):
        """Draw ``z = mu + sigma * eps`` with ``eps ~ N(0, 1)``.

        Args:
            mu: mean tensor.
            log_var: log-variance tensor, broadcastable with ``mu``.

        Returns:
            ``(z, eps)`` where ``eps`` is the standard-normal noise that was
            used. The noise is no longer overwritten in place, so ``eps`` is a
            tensor distinct from ``z``.
        """
        sigma = torch.exp(log_var / 2)
        eps = torch.randn_like(mu)
        z = eps * sigma + mu
        return z, eps

    def generate_m(self, samples):
        """Sample the appearance component; returns ``(m, mu_m, log_var_m)``."""
        feature = self.res_appear(samples)
        mu_m, log_var_m = torch.chunk(feature, 2, dim=1)
        log_var_m = torch.clamp(log_var_m, -20, 0)
        m, _ = self.sample_normal_jit(mu_m, log_var_m)
        return m, mu_m, log_var_m

    def generate_x(self, samples):
        """Sample the shape component; returns ``(x, mu_x, log_var_x)``."""
        feature = self.res_shape(samples)
        mu_x, log_var_x = torch.chunk(feature, 2, dim=1)
        log_var_x = torch.clamp(log_var_x, -20, 0)
        x, _ = self.sample_normal_jit(mu_x, log_var_x)
        return x, mu_x, log_var_x

    def generate_z(self, x):
        """Sample class probabilities from the shape component ``x``.

        Returns ``(z, mu_z, log_var_z)`` where ``z`` and ``mu_z`` have been
        passed through Gumbel-softmax (training) or softmax (evaluation) over
        the class dimension.
        """
        # The U-Net is fed three channels, so the single-channel x is repeated.
        feature = self.unet(x.repeat(1, 3, 1, 1, 1))
        mu_z, log_var_z = torch.chunk(feature, 2, dim=1)
        log_var_z = torch.clamp(log_var_z, -20, 0)
        z, _ = self.sample_normal_jit(mu_z, log_var_z)
        if self.training:
            return F.gumbel_softmax(z, dim=1), F.gumbel_softmax(mu_z, dim=1), log_var_z
        else:
            return self.softmax(z), self.softmax(mu_z), log_var_z

    def forward(self, samples: torch.Tensor):
        """Run the model on a batch of volumes.

        Args:
            samples: 5-D tensor; its last three dimensions are unpacked as
                ``W, H, D`` and it is passed directly to ``res_shape`` and
                ``res_appear``.

        Returns:
            dict with ``pred_masks`` (sampled ``z`` in training mode, ``mu_z``
            otherwise), the ``kl_*`` terms, ``normalization`` and the
            ``rho`` / ``omega`` / ``upsilon`` maps.
        """
        x, mu_x, log_var_x = self.generate_x(samples)
        m, mu_m, log_var_m = self.generate_m(samples)
        z, mu_z, log_var_z = self.generate_z(x)
        K = self.num_classes
        _, _, W, H, D = samples.shape

        residual = samples - (x + m)
        mu_rho_hat = (2 * self.args.gamma_rho + 1) / (
            residual * residual + 2 * self.args.phi_rho
        )
        mu_rho_hat = torch.clamp(mu_rho_hat, 1e4, 1e8)

        normalization = torch.sum(mu_rho_hat).detach()

        # # Image line upsilon
        alpha_upsilon_hat = 2 * self.args.gamma_upsilon + K
        difference_x = F.conv3d(mu_x, self.Dx, padding=1)

        beta_upsilon_hat = (
            torch.sum(
                mu_z * (difference_x * difference_x + 2 * torch.exp(log_var_x)),
                dim=1,
                keepdim=True,
            )
            + 2 * self.args.phi_upsilon
        )  # B x 1 x W x H x D

        mu_upsilon_hat = alpha_upsilon_hat / beta_upsilon_hat
        # mu_upsilon_hat = torch.clamp(mu_upsilon_hat, 1e6, 1e10)

        # # Seg boundary omega
        # groups=K applies the same kernel to every class channel independently.
        difference_z = F.conv3d(
            mu_z, self.Dx.expand(K, 1, 3, 3, 3), padding=1, groups=K
        )  # B x K x W x H x D
        alpha_omega_hat = 2 * self.args.gamma_omega + 1
        pseudo_pi = torch.mean(mu_z, dim=(2, 3, 4), keepdim=True)
        beta_omega_hat = (
            pseudo_pi * (difference_z * difference_z + 2 * torch.exp(log_var_z))
            + 2 * self.args.phi_omega
        )
        mu_omega_hat = alpha_omega_hat / beta_omega_hat
        mu_omega_hat = torch.clamp(mu_omega_hat, 1e2, 1e6)

        # # Seg category probability pi
        alpha_pi_hat = self.args.alpha_pi + W * H * D / 2
        beta_pi_hat = (
            torch.sum(
                mu_omega_hat * (difference_z * difference_z + 2 * torch.exp(log_var_z)),
                dim=(2, 3, 4),
                keepdim=True,
            )
            / 2
            + self.args.beta_pi
        )
        digamma_pi = torch.special.digamma(
            alpha_pi_hat + beta_pi_hat
        ) - torch.special.digamma(beta_pi_hat)

        # # compute loss-related
        kl_y = residual * mu_rho_hat.detach() * residual

        kl_mu_z = torch.sum(
            digamma_pi.detach() * difference_z * mu_omega_hat.detach() * difference_z,
            dim=1,
        )
        kl_sigma_z = torch.sum(
            digamma_pi.detach()
            * (2 * torch.exp(log_var_z) * mu_omega_hat.detach() - log_var_z),
            dim=1,
        )

        kl_mu_x = torch.sum(
            difference_x * difference_x * mu_upsilon_hat.detach() * mu_z.detach(), dim=1
        )
        # Reduce log_var_x over the channel dim as well: subtracting the
        # un-reduced (B, 1, W, H, D) tensor from the (B, W, H, D) sum would
        # broadcast to (B, B, W, H, D) whenever the batch size exceeds 1.
        kl_sigma_x = torch.sum(
            2 * torch.exp(log_var_x) * mu_upsilon_hat.detach() * mu_z.detach(),
            dim=1,
        ) - torch.sum(log_var_x, dim=1)

        kl_mu_m = self.args.sigma_0 * mu_m * mu_m
        kl_sigma_m = self.args.sigma_0 * torch.exp(log_var_m) - log_var_m

        pred = z if self.training else mu_z
        out = {
            "pred_masks": pred,
            "kl_y": kl_y,
            "kl_mu_z": kl_mu_z,
            "kl_sigma_z": kl_sigma_z,
            "kl_mu_x": kl_mu_x,
            "kl_sigma_x": kl_sigma_x,
            "kl_mu_m": kl_mu_m,
            "kl_sigma_m": kl_sigma_m,
            "normalization": normalization,
            "rho": mu_rho_hat,
            "omega": mu_omega_hat * digamma_pi,
            "upsilon": mu_upsilon_hat * mu_z,
        }
        return out
    
class BayeSeg_Criterion(Criterion):
    """Loss for BayeSeg: Dice/CE term plus the weighted Bayesian term."""

    def __init__(self, args):
        super().__init__(args)
        # Weight applied to the Bayesian term in forward().
        self.bayes_loss_coef = args.bayes_loss_coef

    def loss_Bayes(self, outputs):
        """Sum every ``kl_*`` map in ``outputs`` and normalise each by
        ``outputs["normalization"]``; returns the total of the seven terms."""
        norm = outputs["normalization"]
        # Same order as the terms are accumulated in the total below.
        term_keys = (
            "kl_y",
            "kl_mu_m",
            "kl_sigma_m",
            "kl_mu_x",
            "kl_sigma_x",
            "kl_mu_z",
            "kl_sigma_z",
        )
        terms = [torch.sum(outputs[key]) / norm for key in term_keys]

        total = terms[0]
        for term in terms[1:]:
            total = total + term
        return total

    def forward(self, pred, grnd):
        """Return ``(total_loss, loss_dict)`` for model output ``pred`` and
        ground truth ``grnd``.

        ``compute_dice_ce_loss`` and ``compute_dice`` are not defined in this
        class; presumably they are inherited from ``Criterion`` -- verify.
        """
        masks = pred["pred_masks"]
        loss_dict = {
            "loss_Dice_CE": self.compute_dice_ce_loss(masks, grnd),
            "Dice": self.compute_dice(masks, grnd),
            "loss_Bayes": self.loss_Bayes(pred),
            "rho": torch.mean(pred["rho"]),
            "omega": torch.mean(pred["omega"]),
            "upsilon": torch.mean(pred["upsilon"]),
        }
        bayes_term = self.bayes_loss_coef * loss_dict["loss_Bayes"]
        losses = loss_dict["loss_Dice_CE"] + bayes_term
        return losses, loss_dict

In [None]:
import torch.optim as optim
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class Args:
    """Minimal stand-in for the argument object consumed by BayeSeg and
    BayeSeg_Criterion in this notebook; every hyper-parameter is set to 1.0
    and ``num_classes`` to 2."""

    def __init__(self):
        self.num_classes = 2

        # Hyperparameters referenced in BayeSeg forward pass
        forward_hyperparameters = (
            "gamma_rho",
            "phi_rho",
            "gamma_upsilon",
            "phi_upsilon",
            "gamma_omega",
            "phi_omega",
            "alpha_pi",
            "beta_pi",
            "sigma_0",
        )
        # Loss weights; add further parameters here if they are needed.
        loss_coefficients = ("bayes_loss_coef", "ce_loss_coef", "dice_loss_coef")

        for attr_name in forward_hyperparameters + loss_coefficients:
            setattr(self, attr_name, 1.0)


args = Args()
model = BayeSeg(args).to(device)
my_loss = BayeSeg_Criterion(args)
# The loop below back-propagates and steps the optimizer, so the model has to
# be in training mode: BayeSeg.forward only returns the sampled prediction
# (and generate_z only uses Gumbel-softmax) when self.training is set.
model.train()

# Dummy input: batch of 1, 1 channel, 160 x 160 x 64 volume with values in {0, 1}
samples = torch.randint(0, 2, (1, 1, 160, 160, 64), dtype=torch.float32, device=device)

optimizer = optim.Adam(model.parameters(), lr=1e-3)

for i in range(100):
    print(f"step {i} start!")
    # (a) Zero out the gradients from the previous iteration
    optimizer.zero_grad()

    # (b) Forward pass: generate predictions
    output = model(samples)

    # Random binary target for demonstration, same shape as `samples`:
    # (1, 1, 160, 160, 64).
    target = torch.randint(0, 2, (1, 1, 160, 160, 64), dtype=torch.float32, device=device)

    # (c) Compute the loss; the criterion returns (total_loss, loss_dict)
    loss = my_loss(output, target)

    # (d) Backward pass: compute gradients of the total loss
    loss[0].backward()

    # (e) Take an optimizer step to update model parameters
    optimizer.step()



step 0 start!


Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)
Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)


step 1 start!
step 2 start!
step 3 start!
step 4 start!
step 5 start!


KeyboardInterrupt: 

In [31]:
class BayeSeg(nn.Module):
    """Variant of BayeSeg with a 2-D (3x3) difference kernel.

    NOTE(review): this cell rebinds the name ``BayeSeg`` that an earlier cell
    already defines, and this version has no ``generate_z`` / ``forward`` --
    confirm that shadowing the earlier class is intended.
    """

    def __init__(self, args):
        super().__init__()

        self.args = args
        self.num_classes = args.num_classes

        self.res_shape = ResNet_shape(num_out_ch=2)
        self.res_appear = ResNet_appearance(num_out_ch=2, num_block=6, bn=True)
        self.unet = get_efficientunet_b2(
            out_channels=2 * args.num_classes, pretrained=False
        )

        self.softmax = nn.Softmax(dim=1)

        # Fixed 3x3 kernel: centre 1, the four edge neighbours -1/4.
        stencil = torch.tensor(
            [
                [0.0, -1 / 4, 0.0],
                [-1 / 4, 1.0, -1 / 4],
                [0.0, -1 / 4, 0.0],
            ],
            dtype=torch.float,
        ).reshape(1, 1, 3, 3)
        self.Dx = nn.Parameter(data=stencil, requires_grad=False)

    @staticmethod
    def sample_normal_jit(mu, log_var):
        """Return ``(z, eps)`` with ``z = mu + exp(log_var / 2) * noise``.

        The noise buffer is updated in place, so the two returned values are
        the same tensor.
        """
        std = (log_var / 2).exp()
        buffer = (mu * 0).normal_()
        buffer.mul_(std).add_(mu)
        return buffer, buffer

    def generate_m(self, samples):
        """Sample the appearance component; returns ``(m, mu_m, log_var_m)``."""
        mu_m, raw_log_var = self.res_appear(samples).chunk(2, dim=1)
        log_var_m = raw_log_var.clamp(-20, 0)
        m = self.sample_normal_jit(mu_m, log_var_m)[0]
        return m, mu_m, log_var_m

    def generate_x(self, samples):
        """Sample the shape component; returns ``(x, mu_x, log_var_x)``."""
        mu_x, raw_log_var = self.res_shape(samples).chunk(2, dim=1)
        log_var_x = raw_log_var.clamp(-20, 0)
        x = self.sample_normal_jit(mu_x, log_var_x)[0]
        return x, mu_x, log_var_x

In [None]:
import math
import torch
from torch import nn as nn
from torch.nn import init as init
from torch.nn.modules.batchnorm import _BatchNorm


@torch.no_grad()
def default_init_weights(module_list, scale=1.0, bias_fill=0.0, **kwargs):
    """Initialize network weights.

    Convolution (1-D, 2-D and 3-D) and linear layers get Kaiming-normal
    weights multiplied by ``scale``; batch-norm layers get weight 1. Biases,
    where present, are filled with ``bias_fill``.

    Args:
        module_list (list[nn.Module] | nn.Module): Modules to be initialized.
            A single module is wrapped in a list; any other iterable of
            modules is iterated as-is.
        scale (float): Scale initialized weights, especially for residual
            blocks. Default: 1.
        bias_fill (float): The value to fill bias. Default: 0
        kwargs (dict): Other arguments for initialization function.
    """
    if isinstance(module_list, nn.Module):
        module_list = [module_list]
    # nn.Conv3d must be matched too: the networks in this file are built from
    # Conv3d layers, which a Conv2d-only check silently skips.
    kaiming_types = (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)
    for module in module_list:
        for m in module.modules():
            if isinstance(m, kaiming_types):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, _BatchNorm):
                # weight/bias are None when the layer was built with affine=False
                if m.weight is not None:
                    init.constant_(m.weight, 1)
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)


def make_layer(basic_block, num_basic_block, **kwarg):
    """Stack ``num_basic_block`` freshly built copies of a block.

    Args:
        basic_block (nn.module): class (or factory) called once per block.
        num_basic_block (int): how many blocks to build.
        **kwarg: forwarded unchanged to every ``basic_block(...)`` call.

    Returns:
        nn.Sequential: the blocks, in construction order.
    """
    blocks = [basic_block(**kwarg) for _ in range(num_basic_block)]
    return nn.Sequential(*blocks)


def default_conv(in_channels, out_channels, kernel_size, strides=1, bias=True):
    """Build an ``nn.Conv3d`` whose padding is ``kernel_size // 2``.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: integer kernel size (it is floor-divided by two to get
            the padding).
        strides: convolution stride. Default: 1.
        bias: whether the layer has a bias term. Default: True.
    """
    half_kernel = kernel_size // 2
    return nn.Conv3d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=strides,
        padding=half_kernel,
        bias=bias,
    )


class ResBlock(nn.Module):
    """Residual block: conv -> [BatchNorm3d] -> act -> conv -> [BatchNorm3d],
    scaled by ``res_scale`` and added to the input.

    Args:
        conv: factory called as ``conv(n_feats, n_feats, kernel_size, bias=bias)``.
        n_feats: channel count of both convolutions.
        kernel_size: kernel size passed to ``conv``.
        bias: bias flag passed to ``conv``.
        bn: insert an ``nn.BatchNorm3d`` after each convolution when True.
        act: activation module placed after the first convolution. The
            default instance is created once, at definition time, and is
            therefore shared by all blocks that use it.
        res_scale: multiplier applied to the body output before the skip add.
    """

    def __init__(
        self,
        conv=default_conv,
        n_feats=64,
        kernel_size=3,
        bias=True,
        bn=False,
        act=nn.ReLU(True),
        res_scale=1,
    ):
        super().__init__()

        # Layers are created in the order they run.
        layers = [conv(n_feats, n_feats, kernel_size, bias=bias)]
        if bn:
            layers.append(nn.BatchNorm3d(n_feats))
        layers.append(act)
        layers.append(conv(n_feats, n_feats, kernel_size, bias=bias))
        if bn:
            layers.append(nn.BatchNorm3d(n_feats))

        self.body = nn.Sequential(*layers)
        self.res_scale = res_scale

    def forward(self, x):
        scaled = self.body(x) * self.res_scale
        scaled += x  # skip connection
        return scaled


class Upsample(nn.Sequential):
    """Upsample module built from (Conv3d, PixelShuffle) pairs.

    A power-of-two ``scale`` yields ``log2(scale)`` pairs that each expand the
    channels by 4 and shuffle by 2; ``scale == 3`` yields one pair that expands
    by 9 and shuffles by 3.

    NOTE(review): ``nn.PixelShuffle`` is documented for ``(*, C, H, W)``
    inputs; confirm it does what is intended on the 5-D output of ``Conv3d``.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.

    Raises:
        ValueError: if ``scale`` is neither a power of two nor 3.
    """

    def __init__(self, scale, num_feat):
        power_of_two = (scale & (scale - 1)) == 0
        if not power_of_two and scale != 3:
            raise ValueError(
                f"scale {scale} is not supported. " "Supported scales: 2^n and 3."
            )

        # (channel expansion, shuffle factor) for every stage
        if power_of_two:
            stages = [(4, 2)] * int(math.log(scale, 2))
        else:
            stages = [(9, 3)]

        layers = []
        for expansion, factor in stages:
            layers.append(nn.Conv3d(num_feat, expansion * num_feat, 3, 1, 1))
            layers.append(nn.PixelShuffle(factor))
        super().__init__(*layers)


class ResNet_appearance(nn.Module):
    """Conv3d -> residual blocks -> Conv3d, with the input added to the output.

    Args:
        num_in_ch: input channels.
        num_out_ch: output channels.
        num_feat: feature channels inside the network.
        num_block: number of ``ResBlock`` modules in the body.
        bn: passed to each ``ResBlock``.
    """

    def __init__(self, num_in_ch=1, num_out_ch=1, num_feat=64, num_block=10, bn=False):
        super().__init__()
        self.conv_first = nn.Conv3d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
        self.body = make_layer(ResBlock, num_block, n_feats=num_feat, bn=bn)
        self.conv_last = nn.Conv3d(num_feat, num_out_ch, kernel_size=3, stride=1, padding=1)

        # activation shared by the head and the tail
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

        # initialise the first and last convolution with scale 0.1
        default_init_weights([self.conv_first, self.conv_last], 0.1)

    def forward(self, x):
        head = self.lrelu(self.conv_first(x))
        trunk = self.body(head)
        result = self.conv_last(self.lrelu(trunk))
        # global skip: the input is added in place to the prediction
        result += x
        return result


class ResNet_shape(nn.Module):
    """Conv3d -> residual blocks -> Conv3d, without a global skip connection.

    Args:
        num_in_ch: input channels.
        num_out_ch: output channels.
        num_feat: feature channels inside the network.
        num_block: number of ``ResBlock`` modules in the body.
        bn: passed to each ``ResBlock``.
    """

    def __init__(self, num_in_ch=1, num_out_ch=1, num_feat=64, num_block=10, bn=False):
        super().__init__()
        self.conv_first = nn.Conv3d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
        self.body = make_layer(ResBlock, num_block, n_feats=num_feat, bn=bn)
        self.conv_last = nn.Conv3d(num_feat, num_out_ch, kernel_size=3, stride=1, padding=1)

        # activation shared by the head and the tail
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

        # initialise the first and last convolution with scale 0.1
        default_init_weights([self.conv_first, self.conv_last], 0.1)

    def forward(self, x):
        head = self.lrelu(self.conv_first(x))
        trunk = self.body(head)
        return self.conv_last(self.lrelu(trunk))


In [9]:
# Shape check: a two-channel ResNet_shape applied to one random
# single-channel 128^3 volume.
model = ResNet_shape(num_out_ch=2)
samples = torch.randn(1, 1, 128, 128, 128)

# Forward pass only, so gradient tracking is switched off.
with torch.no_grad():
    output = model(samples)

print(output.shape)

torch.Size([1, 2, 128, 128, 128])


In [11]:
import os
import shutil
import random

# Define dataset directory (update this path as needed)
dataset_dir = r"/home/molloi-lab-linux2/Desktop/BayeSeg/dataset/imageCAS3d/aaa"
out_dir = r"/home/molloi-lab-linux2/Desktop/BayeSeg/dataset/imageCAS3d/crop"

# Define output directories
output_dirs = {
    "train": os.path.join(out_dir, "train"),
    "val": os.path.join(out_dir, "val"),
    "test": os.path.join(out_dir, "test")
}

# Create output directories if they do not exist: shutil.move does not create
# missing parent folders, so a fresh run would otherwise fail on the first move.
for split in output_dirs.values():
    os.makedirs(split, exist_ok=True)

# Step 1: Find all unique cases (without "_Segmentation")
IMAGE_SUFFIX = ".nii.gz"
all_files = os.listdir(dataset_dir)
all_cases = set()

for file_name in all_files:
    if file_name.endswith(IMAGE_SUFFIX) and "_Segmentation" not in file_name:
        # Strip only the trailing extension; str.replace would also remove
        # any earlier occurrence of ".nii.gz" inside the name.
        case_id = file_name[: -len(IMAGE_SUFFIX)]
        all_cases.add(case_id)

# Sort first (set order is arbitrary), then shuffle with a dedicated seeded
# generator so that the split is reproducible across runs.
SPLIT_SEED = 42
all_cases = sorted(all_cases)
random.Random(SPLIT_SEED).shuffle(all_cases)

# Step 2: Define split sizes
train_size = 800
val_size = 40
test_size = 160

# Ensure we have enough cases
assert len(all_cases) >= train_size + val_size + test_size, "Not enough cases for the requested split!"

# Step 3: Split the dataset into three consecutive, non-overlapping slices
train_cases = all_cases[:train_size]
val_cases = all_cases[train_size:train_size + val_size]
test_cases = all_cases[train_size + val_size:train_size + val_size + test_size]

# Function to move paired files (image + segmentation) to the respective split folders
def move_case(case_id, destination, source_dir=None):
    """Move a case's image and segmentation file into ``destination``.

    The two candidate files are ``<case_id>.nii.gz`` and
    ``<case_id>_Segmentation.nii.gz``; a file that does not exist in the
    source folder is skipped.

    Args:
        case_id: case identifier (file name without suffix and extension).
        destination: folder to move the files into; created if missing,
            because ``shutil.move`` does not create parent folders.
        source_dir: folder to take the files from. Defaults to the
            notebook-level ``dataset_dir``.

    Returns:
        list[str]: destination paths of the files that were actually moved.
    """
    if source_dir is None:
        source_dir = dataset_dir
    os.makedirs(destination, exist_ok=True)

    moved = []
    for suffix in ("", "_Segmentation"):  # image first, then its segmentation
        file_name = f"{case_id}{suffix}.nii.gz"
        src_path = os.path.join(source_dir, file_name)
        dest_path = os.path.join(destination, file_name)
        if os.path.exists(src_path):  # Check if file exists before moving
            moved.append(shutil.move(src_path, dest_path))
    return moved

# Step 4: Move every case of each split into that split's folder
for split_name, split_cases in (
    ("train", train_cases),
    ("val", val_cases),
    ("test", test_cases),
):
    for case in split_cases:
        move_case(case, output_dirs[split_name])

print(f"Dataset successfully split into:\nTrain: {len(train_cases)} cases\nValidation: {len(val_cases)} cases\nTest: {len(test_cases)} cases")


Dataset successfully split into:
Train: 800 cases
Validation: 40 cases
Test: 160 cases
