## Prepare folder

In [None]:
# Install/upgrade SKM-TEA from PyPI (not from GitHub), pinning pytorch-lightning to 1.7.7.
!pip install --upgrade pytorch-lightning==1.7.7 skm-tea

In [None]:
import os
from pprint import pprint

import numpy as np
import torch
import h5py
import matplotlib.pyplot as plt
from skimage.color import label2rgb
import pandas as pd
from torch import nn

import dosma as dm

import meddlr.ops as oF
from meddlr.data import DatasetCatalog, MetadataCatalog
from meddlr.utils.logger import setup_logger
from meddlr.utils import env

import skm_tea as st
# import tqdm
from tqdm import tqdm

from meddlr.data.data_utils import collect_mask

from typing import Union, Sequence

import os,sys
import argparse
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import json

# Make modules in the current working directory importable. os.getcwd must be
# *called*: appending the function object itself was a silent no-op.
sys.path.append(os.getcwd())
print(sys.path)


if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

print("Device: ", DEVICE)

# Pin work to GPU 1 (as originally intended) only when that GPU exists; an
# unconditional set_device(1) crashes on CPU-only and single-GPU machines.
# torch.device("cuda") above refers to whichever GPU is current.
GPU_INDEX = 1
if torch.cuda.is_available() and torch.cuda.device_count() > GPU_INDEX:
    torch.cuda.set_device(GPU_INDEX)


folder_path = './v1-release/image_files'
# os.listdir order is arbitrary; sort it so the train/test split is reproducible.
image_names = sorted(os.listdir(folder_path))

image_names_train = image_names[:5]
# Test on every remaining volume; the original [6:] silently skipped image_names[5].
image_names_test = image_names[5:]

print("train:", image_names_train, end='  ')

## Load Function

### plot function

In [2]:
def get_scaled_image(
        x: Union[torch.Tensor, np.ndarray], percentile=0.99, clip=False
):
    """Scale an image by an intensity quantile (and optionally clip to [0, 1]).

    Args:
      x (torch.Tensor | np.ndarray): The image to process.
      percentile (float): Quantile in [0, 1] to scale by
        (e.g. ``0.99`` divides by the 99th-percentile intensity).
      clip (bool): If True, clip values between [0, 1].

    Returns:
      torch.Tensor | np.ndarray: The scaled image, in the same container type
        as ``x``. Integer / half-precision inputs come back as float32.
    """
    is_numpy = isinstance(x, np.ndarray)
    if is_numpy:
        x = torch.as_tensor(x)

    # torch.quantile only accepts float32/float64 tensors; integer images
    # (e.g. raw uint16/int64 scans) would otherwise raise a RuntimeError.
    if not x.is_complex() and x.dtype not in (torch.float32, torch.float64):
        x = x.to(torch.float32)

    scale_factor = torch.quantile(x, percentile)
    x = x / scale_factor
    if clip:
        x = torch.clip(x, 0, 1)

    if is_numpy:
        x = x.numpy()

    return x


def plot_images(
        images, processor=None, disable_ticks=True, titles: Sequence[str] = None,
        ylabel: str = None, xlabels: Sequence[str] = None, cmap: str = "gray",
        show_cbar: bool = False, overlay=None, opacity: float = 0.3,
        hsize=5, wsize=5, axs=None
):
    """Plot multiple images in a single row.

    Add an overlay with the `overlay=` argument.
    Add a colorbar with `show_cbar=True`.

    Args:
      images: Sequence of images accepted by ``Axes.imshow``.
      processor: Optional callable applied to each image before plotting.
      disable_ticks: If True, remove x/y ticks from every axis.
      titles: Per-image titles (default: empty strings).
      ylabel: Y-axis label drawn on the first (left-most) axis.
      xlabels: Per-image x-axis labels (default: empty strings).
      cmap: Colormap passed to ``imshow``.
      show_cbar: If True, add one colorbar at the right of the figure.
      overlay: Optional image drawn over every axis with ``alpha=opacity``.
      opacity: Alpha of the overlay.
      hsize, wsize: Height / width (inches) of each subplot when axes are created.
      axs: Optional existing axes (any shape) with at least ``len(images)`` entries.

    Returns:
      A 1D array of axes when they are created here, otherwise ``axs`` as passed.
    """

    def get_default_values(x, default=""):
        if x is None:
            return [default] * len(images)
        return x

    titles = get_default_values(titles)
    xlabels = get_default_values(xlabels)

    N = len(images)
    if axs is None:
        # squeeze=False keeps an array even for N == 1 (a bare Axes object is
        # neither iterable nor has .flatten()).
        fig, axs = plt.subplots(1, N, figsize=(wsize * N, hsize), squeeze=False)
        axs = axs[0]
        flat_axs = axs
    else:
        # Accept 1D or 2D axes arrays (e.g. from plt.subplots(2, 3)).
        flat_axs = np.asarray(axs).flatten()
        assert len(flat_axs) >= N
        fig = flat_axs[0].get_figure()

    im = None
    for ax, img, title, xlabel in zip(flat_axs, images, titles, xlabels):
        if processor is not None:
            img = processor(img)
        im = ax.imshow(img, cmap=cmap)
        ax.set_title(title)
        ax.set_xlabel(xlabel)

    if ylabel:
        flat_axs[0].set_ylabel(ylabel)

    if overlay is not None:
        for ax in flat_axs:
            im = ax.imshow(overlay, alpha=opacity)

    if show_cbar and im is not None:
        fig.subplots_adjust(right=0.8)
        cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
        fig.colorbar(im, cax=cbar_ax)

    if disable_ticks:
        for ax in flat_axs:
            ax.get_xaxis().set_ticks([])
            ax.get_yaxis().set_ticks([])

    return axs


### Load loss function

In [3]:
import torch.nn as nn
import torch.nn.functional as F



def CE_Loss(inputs, target, cls_weights, num_classes=21):
    """Pixel-wise cross-entropy for 2D (U-Net layout) logits.

    Args:
        inputs: Logits shaped (N, C, H, W).
        target: Integer class map shaped (Nt, Ht, Wt). Logits are bilinearly
            resized to (Ht, Wt) when the spatial sizes differ.
        cls_weights: Per-class weights forwarded to ``nn.CrossEntropyLoss``.
        num_classes: Label value used as ``ignore_index``.

    Returns:
        Scalar mean cross-entropy loss.
    """
    # Only the 2D layout is supported; a V-Net variant would need a 5D permute.
    _, channels, height, width = inputs.size()
    _, tgt_h, tgt_w = target.size()
    if (height, width) != (tgt_h, tgt_w):
        inputs = F.interpolate(inputs, size=(tgt_h, tgt_w), mode='bilinear', align_corners=True)

    # (N, C, H, W) -> (N, H, W, C) -> (N*H*W, C): one logit row per pixel.
    flat_logits = inputs.permute(0, 2, 3, 1).contiguous().view(-1, channels)
    flat_labels = target.contiguous().view(-1)

    criterion = nn.CrossEntropyLoss(weight=cls_weights, ignore_index=num_classes)
    return criterion(flat_logits, flat_labels)

def BCE_Loss(inputs, target, vnet=False):
    """Multi-label binary cross-entropy (with logits) for U-Net / V-Net outputs.

    Args:
        inputs: Logits shaped (N, C, H, W), or (N, C, D, H, W) when ``vnet=True``.
        target: Channels-last mask shaped (Nt, Ht, Wt, C). In the V-Net layout
            its leading dims must flatten to the same voxel count as the logits
            (e.g. (D, H, W, C) with batch size 1).
        vnet: Whether ``inputs`` is a 5D volumetric tensor.

    Returns:
        Scalar mean ``BCEWithLogitsLoss`` over all pixels/voxels and channels.
    """
    if vnet:
        n, c, d, h, w = inputs.size()  # V-Net
    else:
        n, c, h, w = inputs.size()     # U-Net
    nt, ht, wt, ct = target.size()
    if h != ht or w != wt:
        if vnet:
            # 'bilinear' only accepts 4D input: resize the in-plane dims of a
            # volume with 'trilinear', keeping its depth.
            inputs = F.interpolate(inputs, size=(d, ht, wt), mode='trilinear', align_corners=True)
        else:
            inputs = F.interpolate(inputs, size=(ht, wt), mode='bilinear', align_corners=True)

    # Move channels last and flatten to one row per pixel / voxel.
    if vnet:
        temp_inputs = inputs.permute(0, 2, 3, 4, 1).contiguous().view(-1, c)
    else:
        temp_inputs = inputs.permute(0, 2, 3, 1).contiguous().view(-1, c)
    temp_target = target.contiguous().view(-1, ct).float()

    return nn.BCEWithLogitsLoss()(temp_inputs, temp_target)

def Dice_Loss(inputs, target, beta=1, smooth=1e-5, vnet=False):
    """Soft F-beta loss on sigmoid probabilities (Dice loss when ``beta == 1``).

        loss = 1 - mean_c[((1 + b^2) TP_c + s) / ((1 + b^2) TP_c + b^2 FN_c + FP_c + s)]

    TP/FP/FN are summed over every pixel/voxel (and the batch) per channel c.
    With beta=1 this is 1 - 2TP / (2TP + FP + FN).

    Args:
        inputs: Logits (N, C, H, W), or (N, C, D, H, W) when ``vnet=True``.
        target: Channels-last mask (Nt, Ht, Wt, C). In the V-Net layout its
            leading dims must flatten to the same voxel count as the logits.
        beta: F-beta weight on recall.
        smooth: Added to numerator and denominator to avoid 0/0.
        vnet: Whether ``inputs`` is a 5D volume (mirrors ``BCE_Loss``).

    Returns:
        Scalar loss tensor.
    """
    if vnet:
        n, c, d, h, w = inputs.size()
    else:
        n, c, h, w = inputs.size()
    nt, ht, wt, ct = target.size()
    if h != ht or w != wt:
        if vnet:
            inputs = F.interpolate(inputs, size=(d, ht, wt), mode='trilinear', align_corners=True)
        else:
            inputs = F.interpolate(inputs, size=(ht, wt), mode='bilinear', align_corners=True)

    # Channels last, one row per pixel / voxel.
    if vnet:
        logits = inputs.permute(0, 2, 3, 4, 1).reshape(-1, c)
    else:
        logits = inputs.permute(0, 2, 3, 1).reshape(-1, c)
    temp_inputs = torch.sigmoid(logits)
    # reshape (not view) tolerates permuted / non-contiguous targets; float()
    # keeps the sums below in floating point for bool/uint8 masks.
    temp_target = target.reshape(-1, ct).float()

    tp = torch.sum(temp_inputs * temp_target, dim=0)
    fp = torch.sum(temp_inputs, dim=0) - tp
    fn = torch.sum(temp_target, dim=0) - tp

    score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
    dice_loss = 1 - torch.mean(score)
    return dice_loss


def weights_init(net, init_type='normal', init_gain=0.02):
    """Re-initialise the convolution and 2D batch-norm layers of ``net`` in place.

    Modules whose class name contains ``'Conv'`` (Conv*, ConvTranspose*) and
    that have a ``weight`` get that weight redrawn according to ``init_type``;
    their bias is left untouched. ``BatchNorm2d`` layers get weight ~ N(1, 0.02)
    and bias 0. All other modules keep their default initialisation.

    Args:
        net: The ``nn.Module`` to initialise.
        init_type: ``'normal'``, ``'xavier'``, ``'kaiming'`` or ``'orthogonal'``.
        init_gain: Std for ``'normal'``, gain for ``'xavier'`` / ``'orthogonal'``
            (ignored by ``'kaiming'``).

    Raises:
        NotImplementedError: If ``init_type`` is unknown and ``net`` contains
            a conv layer.
    """
    conv_initialisers = {
        'normal': lambda weight: torch.nn.init.normal_(weight, 0.0, init_gain),
        'xavier': lambda weight: torch.nn.init.xavier_normal_(weight, gain=init_gain),
        'kaiming': lambda weight: torch.nn.init.kaiming_normal_(weight, a=0, mode='fan_in'),
        'orthogonal': lambda weight: torch.nn.init.orthogonal_(weight, gain=init_gain),
    }

    def init_func(module):
        layer_name = type(module).__name__
        if 'Conv' in layer_name and hasattr(module, 'weight'):
            initialiser = conv_initialisers.get(init_type)
            if initialiser is None:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            initialiser(module.weight.data)
        elif 'BatchNorm2d' in layer_name:
            torch.nn.init.normal_(module.weight.data, 1.0, 0.02)
            torch.nn.init.constant_(module.bias.data, 0.0)

    print('initialize network with %s type' % init_type)
    net.apply(init_func)

## Load Model

### U-Net Model

In [21]:
class UNet_32(nn.Module):
    """2D U-Net with a 32-channel first stage and five 2x down/up-sampling levels.

    Input: (N, 1, H, W) with H and W divisible by 32. Output: (N, 4, H, W) logits.
    Submodule attribute names (enc_convK, dec_convK, upconvK, BatchNorm2dK,
    final_conv) are the state_dict keys used by saved checkpoints, so they are
    built with explicit names in a fixed registration order.
    """

    # Channel width of encoder stages enc_conv0 ... enc_conv5.
    _WIDTHS = (32, 64, 128, 256, 512, 1024)

    def __init__(self):
        super(UNet_32, self).__init__()
        widths = self._WIDTHS

        # Contracting path: enc_conv0 (1 -> 32) ... enc_conv5 (512 -> 1024).
        in_ch = 1
        for level, out_ch in enumerate(widths):
            setattr(self, f'enc_conv{level}', self.conv_stage(in_ch, out_ch))
            in_ch = out_ch

        # Expanding path: dec_conv4 (1024 -> 512) ... dec_conv0 (64 -> 32). Each
        # consumes the upsampled map concatenated with the matching skip.
        for level in range(4, -1, -1):
            setattr(self, f'dec_conv{level}', self.conv_stage(widths[level + 1], widths[level]))

        # 2x transposed-conv upsamplers, each halving the channel count.
        for level in range(4, -1, -1):
            setattr(self, f'upconv{level}', nn.ConvTranspose2d(widths[level + 1], widths[level], 2, stride=2))

        self.ReLU = nn.ReLU()
        for level in range(4, -1, -1):
            setattr(self, f'BatchNorm2d{level}', nn.BatchNorm2d(widths[level]))

        # 1x1 projection to the 4 output channels.
        self.final_conv = nn.Conv2d(32, 4, 1)

    def conv_stage(self, in_channels, out_channels):
        """Two 3x3 same-padding convolutions with ReLU, followed by BatchNorm."""
        layers = [
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        # Encoder: keep each level's pre-pooling feature map as a skip.
        skips = []
        for level in range(5):
            x = getattr(self, f'enc_conv{level}')(x)
            skips.append(x)
            x = F.max_pool2d(x, 2)
        x = self.enc_conv5(x)

        # Decoder: upsample -> ReLU -> BatchNorm -> concat skip -> conv stage.
        for level in range(4, -1, -1):
            x = getattr(self, f'upconv{level}')(x)
            x = self.ReLU(x)
            x = getattr(self, f'BatchNorm2d{level}')(x)
            x = torch.cat((x, skips[level]), dim=1)
            x = getattr(self, f'dec_conv{level}')(x)

        return self.final_conv(x)

### V-Net Model

In [48]:
class VNet_80(nn.Module):
    """Compact 3D encoder-decoder with three 2x poolings and additive skips.

    Input: (N, 1, D, H, W) with D, H, W divisible by 8.
    Output: (N, 4, D, H, W) logits.
    """

    def __init__(self):
        super(VNet_80, self).__init__()
        # Encoder stages down1..down4: 1 -> 16 -> 32 -> 64 -> 128 channels.
        for idx, (cin, cout) in enumerate([(1, 16), (16, 32), (32, 64), (64, 128)], start=1):
            setattr(self, f'down{idx}', self._conv_block(cin, cout))
        # No down5: one level shallower than a standard V-Net.

        # Decoder upsamplers up4, up3, up2 halve the channel count each time.
        for idx, (cin, cout) in zip((4, 3, 2), [(128, 64), (64, 32), (32, 16)]):
            setattr(self, f'up{idx}', self._up_block(cin, cout))

        self.final_conv = nn.Conv3d(16, 4, kernel_size=1)

    def _conv_block(self, in_channels, out_channels):
        """(Conv3d 3x3x3 -> BatchNorm3d -> ReLU) twice; spatial size is preserved."""
        def unit(cin):
            return [
                nn.Conv3d(cin, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm3d(out_channels),
                nn.ReLU(inplace=True),
            ]
        return nn.Sequential(*unit(in_channels), *unit(out_channels))

    def _up_block(self, in_channels, out_channels):
        """2x transposed-conv upsampling -> BatchNorm3d -> ReLU."""
        layers = [
            nn.ConvTranspose3d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        # Encoder: full-resolution features, then three max-pooled stages.
        skip1 = self.down1(x)
        skip2 = self.down2(F.max_pool3d(skip1, 2))
        skip3 = self.down3(F.max_pool3d(skip2, 2))
        out = self.down4(F.max_pool3d(skip3, 2))

        # Decoder: upsample, then add (not concatenate) the matching skip.
        for up, skip in ((self.up4, skip3), (self.up3, skip2), (self.up2, skip1)):
            out = torch.add(up(out), skip)

        return self.final_conv(out)

### Transformer-Net Model

In [4]:
import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    """Return ``t`` unchanged if it is already a tuple, else ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t, t)

# classes

class FeedForward(nn.Module):
    """Pre-norm transformer MLP: LayerNorm -> Linear -> GELU -> Dropout -> Linear -> Dropout.

    Maps (..., dim) -> (..., dim); the caller adds the residual connection.
    """

    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        layers = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out

class Attention(nn.Module):
    """Pre-norm multi-head self-attention over a token sequence.

    Input/output: (batch, tokens, dim). Uses ``heads`` heads of width
    ``dim_head``; the concatenated heads are projected back to ``dim``, unless
    there is a single head whose width already equals ``dim``.
    """

    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = heads * dim_head
        needs_projection = not (heads == 1 and dim_head == dim)

        self.heads = heads
        # 1 / sqrt(d_head) scaling of the query-key dot products.
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        # A single bias-free linear layer produces queries, keys and values.
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        if needs_projection:
            self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        x = self.norm(x)

        def split_heads(t):
            return rearrange(t, 'b n (h d) -> b h n d', h = self.heads)

        q, k, v = (split_heads(t) for t in self.to_qkv(x).chunk(3, dim = -1))

        scores = (q @ k.transpose(-1, -2)) * self.scale
        weights = self.dropout(self.attend(scores))

        merged = rearrange(weights @ v, 'b h n d -> b n (h d)')
        return self.to_out(merged)

class Transformer(nn.Module):
    """``depth`` pre-norm residual blocks (self-attention, then MLP) followed by a final LayerNorm.

    Input/output: (batch, tokens, dim).
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([
            nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout),
            ])
            for _ in range(depth)
        ])

    def forward(self, x):
        for attention, feed_forward in self.layers:
            x = x + attention(x)
            x = x + feed_forward(x)
        return self.norm(x)

# [16,16,1024]  -> [256, 1024]

# dim = 1024
# depth = 12
# heads = 8
# dim_head = 64
# mlp_dim = 150
# dropout = 0.2
# Utransformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)


class TransNet(nn.Module):
    """U-Net with a Transformer applied to the bottleneck feature map.

    The convolutional parts have the same layout and parameter names as
    ``UNet_32``. Input: (N, 1, H, W) with H and W divisible by 32; output:
    (N, 4, H, W) logits. The bottleneck (N, 1024, H/32, W/32) is flattened
    row-major into H/32 * W/32 tokens of width 1024 (256 tokens for 512x512).
    """

    def __init__(self):
        super(TransNet, self).__init__()

        # Contracting Path (Encoder)
        self.enc_conv0 = self.conv_stage(1, 32)  # initial stage with 32 channels
        self.enc_conv1 = self.conv_stage(32, 64)
        self.enc_conv2 = self.conv_stage(64, 128)
        self.enc_conv3 = self.conv_stage(128, 256)
        self.enc_conv4 = self.conv_stage(256, 512)
        self.enc_conv5 = self.conv_stage(512, 1024)

        # Expanding Path (Decoder): input = upsampled map concatenated with skip
        self.dec_conv4 = self.conv_stage(1024 , 512)
        self.dec_conv3 = self.conv_stage(512, 256)
        self.dec_conv2 = self.conv_stage(256 , 128)
        self.dec_conv1 = self.conv_stage(128 , 64)
        self.dec_conv0 = self.conv_stage(64 , 32)  # final stage with 32 channels

        self.upconv4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.upconv3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.upconv2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.upconv1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.upconv0 = nn.ConvTranspose2d(64, 32, 2, stride=2)

        self.ReLU = nn.ReLU()
        self.BatchNorm2d4 = nn.BatchNorm2d(512)
        self.BatchNorm2d3 = nn.BatchNorm2d(256)
        self.BatchNorm2d2 = nn.BatchNorm2d(128)
        self.BatchNorm2d1 = nn.BatchNorm2d(64)
        self.BatchNorm2d0 = nn.BatchNorm2d(32)

        # Final Output
        self.final_conv = nn.Conv2d(32, 4, 1)

        # Token normalisation before the transformer (tokens keep width 1024).
        self.to_patch_embedding = nn.Sequential(
            nn.LayerNorm(1024),
        )

        dim = 1024
        depth = 8
        heads = 6
        dim_head = 32
        mlp_dim = 768
        dropout = 0.2
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

    def conv_stage(self, in_channels, out_channels):
        """Two 3x3 same-padding convolutions with ReLU, followed by BatchNorm."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        # Encoder (shapes are N x C x H x W for a 1-channel 512x512 input)
        enc0 = self.enc_conv0(x)    # [N, 32, 512, 512]
        x = F.max_pool2d(enc0, 2)   # [N, 32, 256, 256]
        enc1 = self.enc_conv1(x)    # [N, 64, 256, 256]
        x = F.max_pool2d(enc1, 2)   # [N, 64, 128, 128]
        enc2 = self.enc_conv2(x)    # [N, 128, 128, 128]
        x = F.max_pool2d(enc2, 2)   # [N, 128, 64, 64]
        enc3 = self.enc_conv3(x)    # [N, 256, 64, 64]
        x = F.max_pool2d(enc3, 2)   # [N, 256, 32, 32]
        enc4 = self.enc_conv4(x)    # [N, 512, 32, 32]
        x = F.max_pool2d(enc4, 2)   # [N, 512, 16, 16]
        x = self.enc_conv5(x)       # [N, 1024, 16, 16]

        # Transformer over the bottleneck: one 1024-d token per spatial
        # position, in row-major order. Derived from the actual feature-map
        # size instead of hard-coding 16x16 / 256 tokens (which only worked
        # for 512x512 inputs); identical result for 512x512.
        n, c, h, w = x.shape
        x = x.flatten(2).transpose(1, 2)            # [N, h*w, 1024]
        x = self.to_patch_embedding(x)
        x = self.transformer(x)
        x = x.transpose(1, 2).reshape(n, c, h, w)   # [N, 1024, h, w]

        # Decoder: upsample -> ReLU -> BatchNorm -> concat skip -> conv stage
        x = self.upconv4(x)
        x = self.ReLU(x)
        x = self.BatchNorm2d4(x)
        x = torch.cat((x, enc4), dim=1)
        x = self.dec_conv4(x)

        x = self.upconv3(x)
        x = self.ReLU(x)
        x = self.BatchNorm2d3(x)
        x = torch.cat((x, enc3), dim=1)
        x = self.dec_conv3(x)

        x = self.upconv2(x)
        x = self.ReLU(x)
        x = self.BatchNorm2d2(x)
        x = torch.cat((x, enc2), dim=1)
        x = self.dec_conv2(x)

        x = self.upconv1(x)
        x = self.ReLU(x)
        x = self.BatchNorm2d1(x)
        x = torch.cat((x, enc1), dim=1)
        x = self.dec_conv1(x)

        x = self.upconv0(x)
        x = self.ReLU(x)
        x = self.BatchNorm2d0(x)
        x = torch.cat((x, enc0), dim=1)
        x = self.dec_conv0(x)

        x = self.final_conv(x)
        return x


## Train Model

### Train U-Net

Prepare dataset

In [None]:
# Load every training volume and stack them along the slice axis (axis 2),
# so each 2D slice becomes one U-Net training sample.
echo1_o = []
segmentations_o = []

for h5_file in tqdm(image_names_train,desc='process dataset: '):
    # Build the full path to the HDF5 file.
    image_file = os.path.join(folder_path, h5_file)
    with h5py.File(image_file, "r") as f:
        echo1_f = f["echo1"][()]  # Shape: (x, y, z)
        echo2 = f["echo2"][()]  # Shape: (x, y, z)  NOTE(review): read but not used in this cell
        segmentations_f = f["seg"][()]  # Shape: (x, y, z, #classes)
        
    # print(echo1_f.shape[2])
    # Uncomment to keep only volumes with 160 slices (as the V-Net section does):
    # if echo1_f.shape[2] == 160:
    echo1_o.append(echo1_f)
    segmentations_o.append(segmentations_f)


# Concatenate all volumes along the slice axis -> (x, y, total_slices[, #classes]).
echo1_o = np.concatenate(echo1_o,axis=2)
segmentations_o = np.concatenate(segmentations_o,axis=2)

Set args

In [35]:
num_classes = 4          # number of classes; the models' final conv emits 4 channels
batch_size = 40          # approx. slices per mini-batch (sets the np.array_split chunk count)
sum_epoch = 100          # training runs epochs [start_epoch, sum_epoch)
pretrained = False       # True: load saved weights instead of calling weights_init
lr = 1e-3                # Adam learning rate
CHECKPOINT = False       # True: resume model/optimizer/scheduler from save_checkPath
dice_loss = False        # True: train with Dice_Loss, False: BCE_Loss
save_modelpath = "./save_model/"                # directory for the final weights
save_checkPath = './checkPoint_model_bce/'      # directory for periodic checkpoints
save_path = os.path.join(save_modelpath, 'unet32_model_100_BCE.pth')  # final weights file
scalar_name = "aver_Loss_bce"                   # TensorBoard scalar group name

Load model

In [None]:
model = UNet_32()

if not pretrained:
    weights_init(model)
else:
    model.load_state_dict(torch.load('./save_model/unet32_model_100.pth', map_location=DEVICE))
    print("Load model success! ")

print(model)
print('------------------')
# Use the device selected at the top of the notebook; a hard-coded .cuda()
# crashes on CPU-only machines.
model = model.to(DEVICE)

optimizer = optim.Adam(model.parameters(), lr, weight_decay=5e-3)
# Multiply the learning rate by 0.94 every 2 epochs.
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.94)
start_epoch = 0

if CHECKPOINT:
    # Resume model / optimizer / scheduler state and the epoch counter.
    # NOTE(review): checkpoints are written every 20 epochs up to sum_epoch;
    # make sure this file name matches one that actually exists.
    path_checkpoint = os.path.join(save_checkPath, 'ckpt_best_120.pth')
    checkpoint = torch.load(path_checkpoint, map_location=DEVICE)
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])


Train Epochs

In [None]:
# Split the stacked slices (axis 2) into mini-batches of ~batch_size slices.
# Use an explicit integer count (np.array_split truncates a float anyway) and
# at least one chunk, so a dataset smaller than batch_size does not raise.
total = echo1_o.shape[2]
num_chunks = max(1, total // batch_size)
x_chunks = np.array_split(echo1_o, num_chunks, axis=2)
segmentation_chunks = np.array_split(segmentations_o, num_chunks, axis=2)

tensorboardPath = "./runs"
os.makedirs(tensorboardPath, exist_ok=True)
writer = SummaryWriter(log_dir=tensorboardPath)

for epoch in range(start_epoch, sum_epoch):
    total_loss = 0

    model.train()
    print('Start Train')

    bat = 0
    with tqdm(total=len(x_chunks), desc=f'Epoch {epoch+1}/{sum_epoch}', postfix=dict) as pbar:
        for chunk, segmentation in zip(x_chunks, segmentation_chunks):

            optimizer.zero_grad()
            with torch.no_grad():
                # Per-batch z-score normalisation of the image slices.
                chunk = torch.as_tensor(chunk).float().to(DEVICE)
                chunk = (chunk - chunk.mean()) / chunk.std()

                chunk = chunk.permute(2, 0, 1)  # (B,H,W) (40,512,512)
                chunk = chunk.unsqueeze(1)      # add a channel dimension (B,C,H,W) (40,1,512,512)

                # Merge label channels into 4 groups: 0, 1, (2,3), (4,5).
                segmentation = collect_mask(segmentation, (0, 1, (2, 3), (4, 5)), out_channel_first=False)  # [512,512,40,4]
                segmentation = torch.as_tensor(segmentation).to(DEVICE)
                segmentation = segmentation.permute(2, 0, 1, 3)   # [40,512,512,4]

            pre_out = model(chunk)   # [40,4,512,512]

            if dice_loss:
                loss = Dice_Loss(pre_out, segmentation)
            else:
                loss = BCE_Loss(pre_out, segmentation)

            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            bat = bat + 1
            pbar.set_postfix(**{
                'cur_loss': loss.item(),
                'averange_loss': total_loss / bat
            })

            pbar.update(1)

    lr_scheduler.step()
    writer.add_scalars(scalar_name, {"Train": total_loss / bat}, epoch)

    if (epoch+1) % 20 == 0:
        checkpoint = {
            'epoch': (epoch+1),
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict()}
        os.makedirs(save_checkPath, exist_ok=True)
        torch.save(checkpoint, os.path.join(save_checkPath, 'ckpt_best_%s.pth' % (epoch + 1)))

writer.close()

# Save the final weights. save_path is a FILE: create only its parent
# directory (the original created a directory named after the .pth file,
# which made torch.save fail).
os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
torch.save(model.state_dict(), save_path)

### Train V-Net

Prepare dataset

In [None]:
# Load the training volumes as a list (one entry per volume, not concatenated):
# the V-Net training loop iterates over whole 3D volumes.
echo1_o = []
segmentations_o = []

for h5_file in tqdm(image_names_train,desc='process dataset: '):

    image_file = os.path.join(folder_path, h5_file)
    with h5py.File(image_file, "r") as f:
        echo1_f = f["echo1"][()]  # Shape: (x, y, z)
        echo2 = f["echo2"][()]  # Shape: (x, y, z)  NOTE(review): read but not used in this cell
        segmentations_f = f["seg"][()]  # Shape: (x, y, z, #classes)
        
    # print(echo1_f.shape[2])
    # Keep only volumes with exactly 160 slices; presumably so each volume
    # splits into four equal 40-slice slabs in the training loop (TODO confirm).
    if echo1_f.shape[2] == 160:
        echo1_o.append(echo1_f)
        segmentations_o.append(segmentations_f)

Set args

In [58]:
num_classes = 4          # number of classes; the model's final conv emits 4 channels
batch_size = 40          # NOTE(review): unused by the V-Net loop, which splits each volume into 4 slabs
sum_epoch = 100          # training runs up to this epoch (exclusive)
pretrained = False       # True: load saved weights instead of calling weights_init
lr = 1e-3                # Adam learning rate
CHECKPOINT = False       # True: resume model/optimizer/scheduler from save_checkPath
dice_loss = False        # True: train with Dice_Loss, False: BCE_Loss
save_modelpath = "./save_model/"                # directory for the final weights
if not os.path.isdir(save_modelpath):
    os.mkdir(save_modelpath)
save_checkPath = './checkPoint_Vmodel_bce/'     # directory for periodic checkpoints
save_path = os.path.join(save_modelpath, 'Vnet80_model_100_BCE.pth')  # final weights file
scalar_name = "Vnet_aver_Loss_bce"              # TensorBoard scalar group name

Load model

In [None]:
model = VNet_80()

if not pretrained:
    # weights_init re-draws Conv3d / ConvTranspose3d weights; BatchNorm3d
    # layers keep PyTorch's defaults (it only matches 'BatchNorm2d').
    weights_init(model)
else:
    model.load_state_dict(torch.load('./save_model/Vnet80_model_100.pth', map_location=DEVICE))
    print("Load model success! ")

print(model)
print('------------------')
# Use the device selected at the top of the notebook instead of a hard-coded .cuda().
model = model.to(DEVICE)


optimizer = optim.Adam(model.parameters(), lr, weight_decay=5e-3)
# Multiply the learning rate by 0.94 every 2 epochs.
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.94)
start_epoch = 0

if CHECKPOINT:
    # Resume model / optimizer / scheduler state and the epoch counter.
    # NOTE(review): make sure this checkpoint file name actually exists.
    path_checkpoint = os.path.join(save_checkPath, 'ckpt_best_120.pth')
    checkpoint = torch.load(path_checkpoint, map_location=DEVICE)
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])

Train Epochs

In [None]:
tensorboardPath = "./runs"
os.makedirs(tensorboardPath, exist_ok=True)
writer = SummaryWriter(log_dir=tensorboardPath)

# Each volume is split into this many slabs along the slice axis.
num_slabs = 4

# Start from start_epoch so that resuming from a checkpoint works
# (the original always restarted at epoch 0).
for epoch in range(start_epoch, sum_epoch):
    total_loss = 0
    steps = 0

    model.train()
    print('Start Train')

    bat = 0
    with tqdm(total=len(echo1_o), desc=f'Epoch {epoch+1}/{sum_epoch}', postfix=dict) as pbar:
        for echo1, seg_volume in zip(echo1_o, segmentations_o):
            x_chunks = np.array_split(echo1, num_slabs, axis=2)
            segmentation_chunks = np.array_split(seg_volume, num_slabs, axis=2)

            # Pair every image slab with ITS OWN label slab. The original
            # advanced the index with `sl += sl`, which stays 0, so every slab
            # was trained against the first slab's labels.
            for x_chunk, seg_chunk in zip(x_chunks, segmentation_chunks):
                with torch.no_grad():
                    chunk = torch.as_tensor(x_chunk).float().to(DEVICE)
                    chunk = (chunk - chunk.mean()) / chunk.std()   # (512,512,S) z-scored
                    chunk = chunk.permute(2, 0, 1)                 # (S,512,512)
                    chunk = chunk.unsqueeze(0).unsqueeze(0)        # (1,1,S,512,512)

                    segmentation = collect_mask(seg_chunk, (0, 1, (2, 3), (4, 5)), out_channel_first=False)  # (x,y,S,4)
                    segmentation = torch.as_tensor(segmentation).to(DEVICE)
                    segmentation = segmentation.permute(2, 0, 1, 3)   # (S,x,y,4)

                # Clear gradients before EVERY optimizer step; zeroing once per
                # volume accumulated gradients across its slabs.
                optimizer.zero_grad()
                pre_out = model(chunk)   # (1,4,S,512,512)

                # View the single volume as S slices (S,4,512,512) so it lines
                # up with the (S,512,512,4) mask in the 2D loss path (same
                # voxel ordering as the vnet=True path; Dice_Loss originally
                # had no vnet option).
                pre_out_2d = pre_out[0].permute(1, 0, 2, 3)
                if dice_loss:
                    loss = Dice_Loss(pre_out_2d, segmentation)
                else:
                    loss = BCE_Loss(pre_out_2d, segmentation)

                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                steps += 1

            bat = bat + 1
            # Average over optimizer steps (slabs), matching the U-Net loop.
            pbar.set_postfix(**{
                'cur_loss': loss.item(),
                'averange_loss': total_loss / steps
                })

            pbar.update(1)

    lr_scheduler.step()
    writer.add_scalars(scalar_name, {"Train": total_loss / max(steps, 1)}, epoch)

    if (epoch+1) % 20 == 0:
        checkpoint = {
            'epoch': (epoch+1),
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict()}
        os.makedirs(save_checkPath, exist_ok=True)
        torch.save(checkpoint, os.path.join(save_checkPath, 'ckpt_best_%s.pth' % (epoch + 1)))

writer.close()

# Save the final weights (save_path is a file; ensure its directory exists).
os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
torch.save(model.state_dict(), save_path)

### Train Transformer-Net

Prepare dataset

In [None]:
# Load every training volume and stack them along the slice axis (axis 2),
# so each 2D slice becomes one TransNet training sample.
echo1_o = []
segmentations_o = []

for h5_file in tqdm(image_names_train,desc='process dataset: '):

    image_file = os.path.join(folder_path, h5_file)
    with h5py.File(image_file, "r") as f:
        echo1_f = f["echo1"][()]  # Shape: (x, y, z)
        echo2 = f["echo2"][()]  # Shape: (x, y, z)  NOTE(review): read but not used in this cell
        segmentations_f = f["seg"][()]  # Shape: (x, y, z, #classes)
        
    # print(echo1_f.shape[2])
    # Uncomment to keep only volumes with 160 slices (as the V-Net section does):
    # if echo1_f.shape[2] == 160:
    echo1_o.append(echo1_f)
    segmentations_o.append(segmentations_f)


# Concatenate all volumes along the slice axis -> (x, y, total_slices[, #classes]).
echo1_o = np.concatenate(echo1_o,axis=2)
segmentations_o = np.concatenate(segmentations_o,axis=2)

Set args

In [6]:
num_classes = 4          # number of classes; the model's final conv emits 4 channels
batch_size = 40          # approx. slices per mini-batch (sets the np.array_split chunk count)
sum_epoch = 100          # training runs epochs [start_epoch, sum_epoch)
pretrained = False       # True: warm-start from saved weights instead of weights_init
lr = 1e-3                # Adam learning rate
CHECKPOINT = False       # True: resume model/optimizer/scheduler from save_checkPath
dice_loss = False        # True: train with Dice_Loss, False: BCE_Loss
save_modelpath = "./save_model/"                 # directory for the final weights
save_checkPath = './checkPoint_Transnet_bce/'    # directory for periodic checkpoints
save_path = os.path.join(save_modelpath, 'transnet_model_100_BCE.pth')  # final weights file
scalar_name = "Transnet_aver_Loss_bce"           # TensorBoard scalar group name

Load model

In [None]:
model = TransNet()

if not pretrained:
    weights_init(model)
else:
    # Warm-start from a trained UNet_32 checkpoint. TransNet reuses every
    # UNet_32 parameter name but also has to_patch_embedding / transformer
    # weights that such a checkpoint lacks, so a strict load always raised.
    # strict=False loads the shared U-Net weights and leaves the transformer
    # at its initial values.
    # NOTE(review): if a TransNet checkpoint was intended, point this at it
    # and restore strict loading.
    load_result = model.load_state_dict(
        torch.load('./save_model/unet32_model_100.pth', map_location=DEVICE), strict=False)
    print("Load model success! ")
    print("Missing keys (left at init):", load_result.missing_keys)

print(model)
print('------------------')
# Use the device selected at the top of the notebook instead of a hard-coded .cuda().
model = model.to(DEVICE)

optimizer = optim.Adam(model.parameters(), lr, weight_decay=5e-3)
# Multiply the learning rate by 0.94 every 2 epochs.
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.94)
start_epoch = 0

if CHECKPOINT:
    # Resume model / optimizer / scheduler state and the epoch counter.
    # NOTE(review): make sure this checkpoint file name actually exists.
    path_checkpoint = os.path.join(save_checkPath, 'ckpt_best_120.pth')
    checkpoint = torch.load(path_checkpoint, map_location=DEVICE)
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])


Train Epochs

In [None]:
# Split the stacked slices (axis 2) into mini-batches of ~batch_size slices,
# using an explicit integer count of at least one chunk.
total = echo1_o.shape[2]
num_chunks = max(1, total // batch_size)
x_chunks = np.array_split(echo1_o, num_chunks, axis=2)
segmentation_chunks = np.array_split(segmentations_o, num_chunks, axis=2)

tensorboardPath = "./runs"
os.makedirs(tensorboardPath, exist_ok=True)
writer = SummaryWriter(log_dir=tensorboardPath)

for epoch in range(start_epoch, sum_epoch):
    total_loss = 0

    model.train()
    print('Start Train')

    bat = 0
    with tqdm(total=len(x_chunks), desc=f'Epoch {epoch+1}/{sum_epoch}', postfix=dict) as pbar:
        for chunk, segmentation in zip(x_chunks, segmentation_chunks):

            optimizer.zero_grad()
            with torch.no_grad():
                # Per-batch z-score normalisation of the image slices.
                chunk = torch.as_tensor(chunk).float().to(DEVICE)
                chunk = (chunk - chunk.mean()) / chunk.std()

                chunk = chunk.permute(2, 0, 1)  # (B,H,W) (40,512,512)
                chunk = chunk.unsqueeze(1)      # add a channel dimension (B,C,H,W) (40,1,512,512)

                # Merge label channels into 4 groups: 0, 1, (2,3), (4,5).
                segmentation = collect_mask(segmentation, (0, 1, (2, 3), (4, 5)), out_channel_first=False)  # [512,512,40,4]
                segmentation = torch.as_tensor(segmentation).to(DEVICE)
                segmentation = segmentation.permute(2, 0, 1, 3)   # [40,512,512,4]

            pre_out = model(chunk)   # [40,4,512,512]

            if dice_loss:
                loss = Dice_Loss(pre_out, segmentation)
            else:
                loss = BCE_Loss(pre_out, segmentation)

            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            bat = bat + 1
            pbar.set_postfix(**{
                'cur_loss': loss.item(),
                'averange_loss': total_loss / bat
            })

            pbar.update(1)

    lr_scheduler.step()
    writer.add_scalars(scalar_name, {"Train": total_loss / bat}, epoch)

    if (epoch+1) % 20 == 0:
        checkpoint = {
            'epoch': (epoch+1),
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lr_scheduler': lr_scheduler.state_dict()}
        os.makedirs(save_checkPath, exist_ok=True)
        torch.save(checkpoint, os.path.join(save_checkPath, 'ckpt_best_%s.pth' % (epoch + 1)))

writer.close()

# Save the final weights. This section does not create save_modelpath itself,
# so make sure the parent directory of the .pth file exists.
os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
torch.save(model.state_dict(), save_path)

## Evaluation

### Eval function

In [16]:
from scipy.spatial.distance import directed_hausdorff
from skimage.measure import find_contours
import scipy

# Dice Similarity Coefficient (DSC)
def dice_similarity_coefficient(pred, targ, *, empty_value=np.nan):
    """Compute the Dice similarity coefficient (DSC) between two masks.

    DSC = 2 * sum(pred * targ) / (sum(pred) + sum(targ)). The product is
    element-wise, so binary masks give the usual set overlap and non-binary
    (soft) masks give a "soft" Dice.

    Args:
        pred: Predicted mask (array-like, broadcast-compatible with ``targ``).
        targ: Target mask.
        empty_value: Returned when both masks are empty, i.e. the ratio would be
            0/0. Defaults to NaN (the value numpy produced before), but without
            numpy's "invalid value encountered" RuntimeWarning.

    Returns:
        The Dice score (1.0 for identical non-empty binary masks).
    """
    pred = np.asarray(pred)
    targ = np.asarray(targ)
    denominator = np.sum(pred) + np.sum(targ)
    if denominator == 0:
        return empty_value
    return (2.0 * np.sum(pred * targ)) / denominator

# def dice_similarity_coefficient_new(pred, targ, cls):
#     intersection = np.sum(pred * targ)
#     return (2. * intersection) / (np.sum(pred) + np.sum(targ))

# # Average Symmetric Surface Distance (ASSD)
# def average_symmetric_surface_distance(pred, targ):
#     # ASSD requires surface extraction which is not trivial. This is a placeholder for the correct implementation.
#     # The following uses directed Hausdorff as a rough approximation, but this is not accurate for real ASSD calculations.
#     u_hausdorff = directed_hausdorff(pred, targ)[0]
#     v_hausdorff = directed_hausdorff(targ, pred)[0]
#     return (u_hausdorff + v_hausdorff) / 2.0

def calculate_surface_distances(mask1, mask2, spacing):
    """
    Calculate pairwise distances between the surface (contour) points of two 2-D masks.

    Contours are extracted with ``find_contours`` at level 0.5; their (row, col)
    coordinates are scaled by ``spacing`` before Euclidean distances are taken.

    :param mask1: 2-D mask.
    :param mask2: 2-D mask.
    :param spacing: A scalar, or one physical spacing value per axis of the masks
        (length ``mask1.ndim``).
    :return: ``(d12, d21)`` where ``d12[i, j]`` is the distance from the i-th contour
        point of ``mask1`` to the j-th contour point of ``mask2`` and ``d21`` is its
        transpose. Two empty arrays if either mask has no contour.
    :raises ValueError: if ``spacing`` is a sequence whose length is not ``mask1.ndim``.
    """
    contours_mask1 = find_contours(mask1, level=0.5, fully_connected='low')
    contours_mask2 = find_contours(mask2, level=0.5, fully_connected='low')

    if len(contours_mask1) == 0 or len(contours_mask2) == 0:
        return np.array([]), np.array([])

    spacing = np.asarray(spacing, dtype=np.float64)
    if spacing.ndim > 0 and spacing.shape != (np.ndim(mask1),):
        raise ValueError(
            f"spacing must be a scalar or have {np.ndim(mask1)} entries, got shape {spacing.shape}"
        )

    # Stack all contours of each mask into a single (N, ndim) point cloud.
    surface_pts_mask1 = np.vstack(contours_mask1).astype(np.float32) * spacing
    surface_pts_mask2 = np.vstack(contours_mask2).astype(np.float32) * spacing

    # cdist(b, a) is exactly cdist(a, b).T, so compute it once.
    distances_mask1_to_mask2 = scipy.spatial.distance.cdist(surface_pts_mask1, surface_pts_mask2)
    return distances_mask1_to_mask2, distances_mask1_to_mask2.T


def average_symmetric_surface_distance(pred, targ, spacing=(1.0, 1.0, 1.0)):
    """
    Calculate the Average Symmetric Surface Distance (ASSD) between the predicted and target masks.

    The volumes are processed slice by slice along axis 0. For each slice, the
    distance from every predicted contour point to the nearest target contour
    point and from every target contour point to the nearest predicted contour
    point are pooled and averaged. The result is the mean over slices in which
    both masks have a contour.

    :param pred: Predicted mask, 3-D (slices along axis 0).
    :param targ: Target mask with the same shape as ``pred``.
    :param spacing: Physical spacing. Either one value per axis of ``pred`` (the
        axis-0 entry is dropped because each slice is 2-D), the 2 in-plane values,
        or a scalar. Defaults to 1 in all dimensions.
    :return: The ASSD, or NaN if no slice has contours in both masks.
    :raises ValueError: if ``pred`` and ``targ`` have different shapes.
    """
    if pred.shape != targ.shape:
        raise ValueError(f"pred and targ shapes differ: {pred.shape} vs {targ.shape}")

    spacing = np.asarray(spacing, dtype=np.float64)
    # Each slice is 2-D, so a full per-axis spacing is reduced to its in-plane part.
    if spacing.ndim == 1 and spacing.size == pred.ndim:
        in_plane_spacing = spacing[1:]
    else:
        in_plane_spacing = spacing

    assd_per_slice = []
    for pred_slice, targ_slice in zip(pred, targ):  # Loop over each slice
        distances_pred_to_targ, _ = calculate_surface_distances(pred_slice, targ_slice, in_plane_spacing)
        if distances_pred_to_targ.size == 0:
            continue  # Skip slices with no contours
        # Row minima: nearest target point for each predicted point.
        # Column minima: nearest predicted point for each target point.
        symmetric_distances = np.concatenate([
            distances_pred_to_targ.min(axis=1),
            distances_pred_to_targ.min(axis=0),
        ])
        assd_per_slice.append(np.mean(symmetric_distances))
    return np.mean(assd_per_slice) if assd_per_slice else np.nan

# Volumetric Overlap Error (VOE)
def volumetric_overlap_error(pred, targ, *, empty_value=np.nan):
    """Compute the volumetric overlap error (VOE = 1 - intersection / union).

    Args:
        pred: Predicted mask (array-like, broadcast-compatible with ``targ``).
        targ: Target mask.
        empty_value: Returned when the union is empty (both masks empty), i.e. the
            ratio would be 0/0. Defaults to NaN (the value numpy produced before),
            but without numpy's "invalid value encountered" RuntimeWarning.

    Returns:
        The VOE (0.0 for identical non-empty binary masks).
    """
    pred = np.asarray(pred)
    targ = np.asarray(targ)
    intersection = np.sum(pred * targ)
    union = np.sum(pred) + np.sum(targ) - intersection
    if union == 0:
        return empty_value
    return 1 - (intersection / union)

# Coefficient of Variation (CV)
def coefficient_of_variation(image):
    """Return the coefficient of variation (std / mean) over all values of ``image``.

    Uses ``np.std`` with its default population form (ddof=0). A zero mean gives
    inf or NaN under numpy's division rules.
    """
    spread = np.std(image)
    centre = np.mean(image)
    return spread / centre

# Function to compute the average of all metrics across all classes
def compute_average_metrics(pred, targ):
    """Average DSC, ASSD, VOE and CV over the classes stored in the last axis.

    Args:
        pred: Predicted masks with classes along the last axis.
        targ: Target masks with the same layout as ``pred``.

    Returns:
        tuple: ``(avg_dsc, avg_assd, avg_voe, avg_cv)``, each the unweighted mean
        over classes. CV is computed on ``pred`` only.
    """
    classes = range(pred.shape[-1])

    per_class_dsc = [dice_similarity_coefficient(pred[..., c], targ[..., c]) for c in classes]
    per_class_assd = [average_symmetric_surface_distance(pred[..., c], targ[..., c]) for c in classes]
    per_class_voe = [volumetric_overlap_error(pred[..., c], targ[..., c]) for c in classes]
    per_class_cv = [coefficient_of_variation(pred[..., c]) for c in classes]

    return tuple(
        np.mean(scores)
        for scores in (per_class_dsc, per_class_assd, per_class_voe, per_class_cv)
    )


### Start Eval U-Net

Set args

In [38]:
# Evaluation settings for the 2-D U-Net.
pretrained, CHECKPOINT = True, False        # load saved weights; don't restore a training checkpoint
save_modelpath = "./save_model/"            # directory with the final model weights
save_checkPath = './checkPoint_model_bce/'  # directory with periodic training checkpoints
save_path = os.path.join(save_modelpath, 'unet32_model_100_BCE.pth')

Load model

In [None]:
# Build the 2-D U-Net and load its weights for evaluation.
model = UNet_32()

if not pretrained:
    weights_init(model)
else:
    # map_location lets weights saved on one device load on another.
    model.load_state_dict(torch.load(save_path, map_location=DEVICE))
    print("Load model success! ")

print(model)
print('------------------')
# Use DEVICE (chosen at the top of the notebook) rather than hard-coding CUDA,
# so the cell also runs on CPU-only machines.
model = model.to(DEVICE)

if CHECKPOINT:
    path_checkpoint = os.path.join(save_checkPath, 'ckpt_best_200.pth')
    checkpoint = torch.load(path_checkpoint, map_location=DEVICE)
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['model'])

# Inference only: any BatchNorm/Dropout layers switch to evaluation behaviour.
model.eval()

Forward model

In [None]:
# Slice-by-slice inference of the 2-D U-Net on every 160-slice test volume.
# out[k] holds the one-hot prediction of the k-th kept volume laid out as
# (H, W, slices, classes) = (512, 512, 160, 4); segmentations_o[k] is the
# matching raw ground truth read from the same file.
segmentations_o = []
out = []

model.eval()
with torch.no_grad():
    for h5_file in tqdm(image_names_test, desc='process Test: '):
        image_file = os.path.join(folder_path, h5_file)
        with h5py.File(image_file, "r") as f:
            echo1_f = f["echo1"][()]  # Shape: (x, y, z)
            segmentations_f = f["seg"][()]  # Shape: (x, y, z, #classes)
        # echo2 is not used by the evaluation, so it is not read from disk.

        if echo1_f.shape[2] != 160:
            continue  # only 160-slice volumes are evaluated

        # z-score normalise the whole volume before slicing it.
        echo1_f = (echo1_f - echo1_f.mean()) / echo1_f.std()
        segmentations_o.append(segmentations_f)

        one_pre = []
        for i in range(echo1_f.shape[2]):
            pic = torch.as_tensor(echo1_f[:, :, i]).float().to(DEVICE)
            pic = pic.unsqueeze(0).unsqueeze(0)  # (B,C,H,W) = (1,1,512,512)

            pre_out = model(pic)
            pre_out = oF.pred_to_categorical(pre_out, activation='sigmoid')  # (1, 512, 512)
            pre_out = oF.categorical_to_one_hot(pre_out, num_categories=4)  # (1, 4, 512, 512)
            one_pre.append(pre_out)

        one_pre = torch.concat(one_pre, dim=0)  # (160, 4, 512, 512)
        # -> (512, 512, 160, 4). Moved to the CPU so predictions for all volumes
        # do not pile up in GPU memory; the next cell converts them to numpy.
        out.append(one_pre.permute(2, 3, 0, 1).cpu())

Process data

In [None]:
# Gather predictions as one numpy array: (N, 512, 512, 160, 4).
result_o = [prediction.cpu().numpy() for prediction in tqdm(out, desc='process Out: ')]

print('result_o stack')
result_o = np.stack(result_o)

# Merge the 6 ground-truth channels into the 4 predicted classes
# (channel groups 0, 1, 2+3, 4+5), keeping channels last.
segs = [
    collect_mask(raw_seg, (0, 1, (2, 3), (4, 5)), out_channel_first=False)
    for raw_seg in tqdm(segmentations_o, desc='process Seg: ')
]

print('segs stack')
result_true_o = np.stack(segs)

print(result_true_o.shape)
print(result_o.shape)

Eval

In [None]:
num_samples = result_o.shape[0]
num_classes = result_true_o.shape[-1]
# One entry per sample, each the mean over classes. ASSD is not computed here,
# so assd_list stays empty.
dsc_list, assd_list, voe_list, cv_list = [], [], [], []

for n in tqdm(range(num_samples), desc='process eval: '):
    sample_pred = result_o[n]
    sample_true = result_true_o[n]
    class_ids = range(num_classes)

    dsc_list.append(np.mean([dice_similarity_coefficient(sample_pred[..., c], sample_true[..., c]) for c in class_ids]))
    voe_list.append(np.mean([volumetric_overlap_error(sample_pred[..., c], sample_true[..., c]) for c in class_ids]))
    cv_list.append(np.mean([coefficient_of_variation(sample_pred[..., c]) for c in class_ids]))

# Average over samples.
avg_dsc, avg_voe, avg_cv = np.mean(dsc_list), np.mean(voe_list), np.mean(cv_list)

print("Eval U-Net:")
print(avg_dsc, avg_voe, avg_cv)

### Start Eval V-Net

Set args

In [61]:
# Evaluation settings for the V-Net.
pretrained, CHECKPOINT = True, False         # load saved weights; don't restore a training checkpoint
save_modelpath = "./save_model/"             # directory with the final model weights
save_checkPath = './checkPoint_Vmodel_bce/'  # directory with periodic training checkpoints
save_path = os.path.join(save_modelpath, 'Vnet80_model_100_BCE.pth')

Load model

In [None]:
# Build the V-Net and load its weights for evaluation.
model = VNet_80()

if not pretrained:
    weights_init(model)
else:
    # map_location lets weights saved on one device load on another.
    model.load_state_dict(torch.load(save_path, map_location=DEVICE))
    print("Load model success! ")

print(model)
print('------------------')
# Use DEVICE (chosen at the top of the notebook) rather than hard-coding CUDA,
# so the cell also runs on CPU-only machines.
model = model.to(DEVICE)

if CHECKPOINT:
    path_checkpoint = os.path.join(save_checkPath, 'ckpt_best_200.pth')
    checkpoint = torch.load(path_checkpoint, map_location=DEVICE)
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['model'])

# Inference only: any BatchNorm/Dropout layers switch to evaluation behaviour.
model.eval()

Forward model

In [None]:
# Slab-wise inference of the V-Net on every 160-slice test volume.
# result_o[k] holds the one-hot prediction of the k-th kept volume as
# (1, 4, 160, 512, 512); segmentations_o[k] is the matching raw ground truth.
# All lists are reset here so nothing carries over from earlier cells.
result_o = []
result_true_o = []
segmentations_o = []

model.eval()  # inference mode for any BatchNorm/Dropout layers
with torch.no_grad():
    for h5_file in tqdm(image_names_test, desc='process Test: '):
        image_file = os.path.join(folder_path, h5_file)
        with h5py.File(image_file, "r") as f:
            echo1_f = f["echo1"][()]  # Shape: (x, y, z)
            segmentations_f = f["seg"][()]  # Shape: (x, y, z, #classes)

        if echo1_f.shape[2] != 160:
            continue  # only 160-slice volumes are evaluated
        segmentations_o.append(segmentations_f)

        result = []
        # Feed the volume as 4 slabs of 40 slices along z.
        for x_chunk in np.array_split(echo1_f, 4, axis=2):
            chunk = torch.as_tensor(x_chunk).float().to(DEVICE)
            chunk = (chunk - chunk.mean()) / chunk.std()  # per-slab z-score, (512,512,40)
            chunk = chunk.permute(2, 0, 1)                 # (40,512,512)
            chunk = chunk.unsqueeze(0).unsqueeze(0)        # (1,1,40,512,512)

            pre_out = model(chunk)                         # (1,4,40,512,512)
            # Threshold and one-hot encode exactly like the U-Net/TransNet cells,
            # so the metrics compare binary masks rather than raw network outputs.
            # NOTE(review): assumes the V-Net returns logits, as the sigmoid
            # post-processing of the other models does — confirm.
            pre_out = oF.pred_to_categorical(pre_out, activation='sigmoid')
            pre_out = oF.categorical_to_one_hot(pre_out, num_categories=4)
            result.append(pre_out)

        result = torch.concat(result, dim=2)  # (1,4,160,512,512)
        # Keep results on the CPU so GPU memory does not grow with every volume.
        result_o.append(result.cpu())

Process data

In [67]:
print('precess Out')
# (N, 4, 160, 512, 512) -> (N, 512, 512, 160, 4): channels last, matching collect_mask's output.
result_o = torch.concat(result_o, dim=0).permute(0, 3, 4, 2, 1).cpu().numpy()

# (512, 512, 160, 6) -> (512, 512, 160, 4): merge channel groups 0, 1, 2+3, 4+5.
segs = [
    collect_mask(raw_seg, (0, 1, (2, 3), (4, 5)), out_channel_first=False)
    for raw_seg in tqdm(segmentations_o, desc='process Seg: ')
]

result_true_o = np.stack(segs)  # (N, 512, 512, 160, 4)

Eval

In [None]:
num_samples = result_o.shape[0]
num_classes = result_true_o.shape[-1]
# One entry per sample, each the mean over classes. ASSD is not computed here,
# so assd_list stays empty.
dsc_list, assd_list, voe_list, cv_list = [], [], [], []

for n in tqdm(range(num_samples), desc='process eval: '):
    sample_pred = result_o[n]
    sample_true = result_true_o[n]
    class_ids = range(num_classes)

    dsc_list.append(np.mean([dice_similarity_coefficient(sample_pred[..., c], sample_true[..., c]) for c in class_ids]))
    voe_list.append(np.mean([volumetric_overlap_error(sample_pred[..., c], sample_true[..., c]) for c in class_ids]))
    cv_list.append(np.mean([coefficient_of_variation(sample_pred[..., c]) for c in class_ids]))

# Average over samples.
avg_dsc, avg_voe, avg_cv = np.mean(dsc_list), np.mean(voe_list), np.mean(cv_list)

print("Eval V-Net:")
print(avg_dsc, avg_voe, avg_cv)

### Start Eval TransNet

Set args

In [9]:
# Evaluation settings for TransNet.
pretrained, CHECKPOINT = True, False           # load saved weights; don't restore a training checkpoint
save_modelpath = "./save_model/"               # directory with the final model weights
save_checkPath = './checkPoint_Transnet_bce/'  # directory with periodic training checkpoints
save_path = os.path.join(save_modelpath, 'transnet_model_100_BCE.pth')

Load model

In [None]:
# Build TransNet and load its weights for evaluation.
model = TransNet()

if not pretrained:
    weights_init(model)
else:
    # map_location lets weights saved on one device load on another.
    model.load_state_dict(torch.load(save_path, map_location=DEVICE))
    print("Load model success! ")

print(model)
print('------------------')
# Use DEVICE (chosen at the top of the notebook) rather than hard-coding CUDA,
# so the cell also runs on CPU-only machines.
model = model.to(DEVICE)

if CHECKPOINT:
    path_checkpoint = os.path.join(save_checkPath, 'ckpt_best_200.pth')
    checkpoint = torch.load(path_checkpoint, map_location=DEVICE)
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['model'])

# Inference only: any BatchNorm/Dropout layers switch to evaluation behaviour.
model.eval()

Forward model

In [None]:
# Slice-by-slice inference of TransNet on every 160-slice test volume.
# out[k] holds the one-hot prediction of the k-th kept volume laid out as
# (H, W, slices, classes) = (512, 512, 160, 4); segmentations_o[k] is the
# matching raw ground truth read from the same file.
segmentations_o = []
out = []

model.eval()
with torch.no_grad():
    for h5_file in tqdm(image_names_test, desc='process Test: '):
        image_file = os.path.join(folder_path, h5_file)
        with h5py.File(image_file, "r") as f:
            echo1_f = f["echo1"][()]  # Shape: (x, y, z)
            segmentations_f = f["seg"][()]  # Shape: (x, y, z, #classes)
        # echo2 is not used by the evaluation, so it is not read from disk.

        if echo1_f.shape[2] != 160:
            continue  # only 160-slice volumes are evaluated

        # z-score normalise the whole volume before slicing it.
        echo1_f = (echo1_f - echo1_f.mean()) / echo1_f.std()
        segmentations_o.append(segmentations_f)

        one_pre = []
        for i in range(echo1_f.shape[2]):
            pic = torch.as_tensor(echo1_f[:, :, i]).float().to(DEVICE)
            pic = pic.unsqueeze(0).unsqueeze(0)  # (B,C,H,W) = (1,1,512,512)

            pre_out = model(pic)  # (1, 4, 512, 512)
            pre_out = oF.pred_to_categorical(pre_out, activation='sigmoid')  # (1, 512, 512)
            pre_out = oF.categorical_to_one_hot(pre_out, num_categories=4)  # (1, 4, 512, 512)
            one_pre.append(pre_out)

        one_pre = torch.concat(one_pre, dim=0)  # (160, 4, 512, 512)
        # -> (512, 512, 160, 4). Moved to the CPU so predictions for all volumes
        # do not pile up in GPU memory; the next cell converts them to numpy.
        out.append(one_pre.permute(2, 3, 0, 1).cpu())

Process data

In [None]:
# Gather predictions as one numpy array: (N, 512, 512, 160, 4).
result_o = [prediction.cpu().numpy() for prediction in tqdm(out, desc='process Out: ')]

print('result_o stack')
result_o = np.stack(result_o)

# Merge the 6 ground-truth channels into the 4 predicted classes
# (channel groups 0, 1, 2+3, 4+5), keeping channels last.
segs = [
    collect_mask(raw_seg, (0, 1, (2, 3), (4, 5)), out_channel_first=False)
    for raw_seg in tqdm(segmentations_o, desc='process Seg: ')
]

print('segs stack')
result_true_o = np.stack(segs)

print(result_true_o.shape)
print(result_o.shape)

Eval

In [None]:
num_samples = result_o.shape[0]
num_classes = result_true_o.shape[-1]
# One entry per sample, each the mean over classes. ASSD is not computed here,
# so assd_list stays empty.
dsc_list, assd_list, voe_list, cv_list = [], [], [], []

for n in tqdm(range(num_samples), desc='process eval: '):
    sample_pred = result_o[n]
    sample_true = result_true_o[n]
    class_ids = range(num_classes)

    dsc_list.append(np.mean([dice_similarity_coefficient(sample_pred[..., c], sample_true[..., c]) for c in class_ids]))
    voe_list.append(np.mean([volumetric_overlap_error(sample_pred[..., c], sample_true[..., c]) for c in class_ids]))
    cv_list.append(np.mean([coefficient_of_variation(sample_pred[..., c]) for c in class_ids]))

# Average over samples.
avg_dsc, avg_voe, avg_cv = np.mean(dsc_list), np.mean(voe_list), np.mean(cv_list)

print("Eval TransNet:")
print(avg_dsc, avg_voe, avg_cv)

## Plot

Load data

In [None]:
# Read the first test volume (echo1 image and ground-truth segmentation).
image_file = os.path.join(folder_path, image_names_test[0])
with h5py.File(image_file, "r") as f:
    echo1_f = f["echo1"][()]  # Shape: (x, y, z)
    segmentations_f = f["seg"][()]  # Shape: (x, y, z, #classes)

# z-score normalise the whole volume.
echo1_f = (echo1_f - echo1_f.mean()) / echo1_f.std()

# Wrap it in a leading sample axis -> (1, x, y, z), then drop the raw reference.
test_data = np.array([echo1_f])
echo1_f = []

print(test_data.shape)
print(segmentations_f.shape)

Load Models

In [None]:
# Load the three trained models compared in the plot below.
# NOTE(review): these file names differ from those loaded by the evaluation
# cells ('transnet_model_100_BCE.pth', 'Vnet80_model_100_BCE.pth',
# 'unet32_model_100_BCE.pth') — confirm which weights are intended.
model_0 = TransNet()
model_0.load_state_dict(torch.load('./save_model/TransNet_model_100.pth', map_location=DEVICE))
print("Load model0 success! ")

model_1 = VNet_80()
model_1.load_state_dict(torch.load('./save_model/VNet_80_model_100.pth', map_location=DEVICE))
print("Load model1 success! ")

model_2 = UNet_32()
model_2.load_state_dict(torch.load('./save_model/unet32_model_100.pth', map_location=DEVICE))
print("Load model2 success! ")

# Move to DEVICE and switch to inference mode, so any BatchNorm/Dropout layers
# use evaluation behaviour instead of training behaviour.
model_0 = model_0.to(DEVICE).eval()
model_1 = model_1.to(DEVICE).eval()
model_2 = model_2.to(DEVICE).eval()

Forward Models

In [None]:

# Per-slice inference on the first test volume for the three plotted models.
# pred[k, :, :, j] receives the categorical label map of model k for slice j
# (k = 0: model_0/TransNet, 1: model_1/VNet_80, 2: model_2/UNet_32).
# NOTE(review): the evaluation cells call these models as model(tensor) and use
# the returned tensor directly; here they are called with {"image": ...} and the
# "sem_seg_logits" key is read — confirm the models accept this interface.
# NOTE(review): the V-Net evaluation cell feeds VNet_80 (1, 1, 40, 512, 512)
# slabs, but here it receives a single (1, 1, 512, 512) slice — confirm this is
# supported.
# NOTE(review): assumes test_data has 160 slices on its last axis (the slice count
# and the pred shape are hard-coded), and only the first volume (range(1)) is used.
pred = np.zeros((3, 512, 512, 160))
for i in range(1):
    for j in range(160):
        echo1_sl = test_data[i,:,:,j]
        # (H, W) -> (1, 1, H, W) float tensor on DEVICE.
        echo1_sl = torch.as_tensor(echo1_sl).unsqueeze(0).unsqueeze(0).float().to(DEVICE)

        with torch.no_grad():
            logits = model_0({"image": echo1_sl})["sem_seg_logits"]
            logits = oF.pred_to_categorical(logits, activation='sigmoid')
            # logits = oF.categorical_to_one_hot(logits,3,num_categories=4)

            # logits = prediction_2
            # logits = logits1.permute(0,2,3,1)
            # prediction = oF.pred_to_categorical(logits, activation='sigmoid').squeeze(0)

            logits = logits.cpu()
            pred[0,:,:,j] = logits[0]

            logits = model_1({"image": echo1_sl})["sem_seg_logits"]
            logits = oF.pred_to_categorical(logits, activation='sigmoid')
            logits = logits.cpu()
            pred[1,:,:,j] = logits[0]

            logits = model_2({"image": echo1_sl})["sem_seg_logits"]
            logits = oF.pred_to_categorical(logits, activation='sigmoid')
            logits = logits.cpu()
            pred[2,:,:,j] = logits[0]


Visualize predictions

In [None]:
# Ground truth: merge the 6 tissue channels into the same 4 classes the models
# predict, then convert to a categorical label map (0 = no tissue) so it can be
# drawn with imshow like the predictions. The raw one-hot (..., 6) slice is not
# a displayable image.
# NOTE(review): assumes one_hot_to_categorical uses the same 0 = background,
# 1..4 = class convention as pred_to_categorical — confirm.
gt_categorical = oF.one_hot_to_categorical(
    collect_mask(segmentations_f, (0, 1, (2, 3), (4, 5)), out_channel_first=False),
    channel_dim=-1,
)

# Predicted label volumes, each (512, 512, 160): TransUNet, V-Net, U-Net.
E1, E2, RSS = pred[0], pred[1], pred[2]

# Row 0: slice 32 along axis 2.
sl = 32
prediction_0, prediction_1, prediction_2 = E1[:, :, sl], E2[:, :, sl], RSS[:, :, sl]
gt_seg_sl = gt_categorical[:, :, sl]
input = test_data[0, ..., sl]

# Row 1: slice 150 along axis 1.
sl = 150
prediction_01, prediction_11, prediction_21 = E1[:, sl, :], E2[:, sl, :], RSS[:, sl, :]
gt_seg_sl_1 = gt_categorical[:, sl, :]
input_1 = test_data[0, :, sl, :]

# Row 2: slice 190 along axis 0.
sl = 190
prediction_02, prediction_12, prediction_22 = E1[sl, :, :], E2[sl, :, :], RSS[sl, :, :]
gt_seg_sl_2 = gt_categorical[sl, :, :]
input_2 = test_data[0, sl, :, :]

titles = ("Input", "TransUNet", "Vnet", "Unet ", "Ground truth")
rows = (
    (input, prediction_0, prediction_1, prediction_2, gt_seg_sl),
    (input_1, prediction_01, prediction_11, prediction_21, gt_seg_sl_1),
    (input_2, prediction_02, prediction_12, prediction_22, gt_seg_sl_2),
)

_, axs = plt.subplots(3, 5, figsize=(15,12))
for row_idx, row in enumerate(rows):
    for idx, (data, title) in enumerate(zip(row, titles)):
        ax = axs[row_idx, idx]
        if idx == 0:
            ax.imshow(np.squeeze(data), cmap="gray")
        else:
            # Pin the colour range to the label range 0..4 so each class gets the
            # same colour in every panel, even when some classes are absent.
            ax.imshow(np.squeeze(data), vmin=0, vmax=4)
        ax.set_title(title, fontsize=20)
        ax.axis("off")

plt.show()