In [56]:
import datetime
import os
import random
import time
from functools import partial
from typing import Optional, Sequence, Tuple, Type, Union, Dict

import numpy as np
import torch
import torch.nn as nn

from monai import data, transforms
from monai.data import decollate_batch
from monai.handlers import CheckpointSaver
from monai.inferers import SlidingWindowInferer
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.networks.blocks.dynunet_block import UnetOutBlock
from monai.networks.blocks.squeeze_and_excitation import ChannelSELayer
from monai.networks.blocks.unetr_block import  UnetrUpBlock
from monai.networks.nets.swin_unetr import SwinTransformer
from monai.transforms import (
    AsDiscrete,
    Activations,
    Compose
)
from monai.utils import UpsampleMode
from monai.utils import set_determinism, ensure_tuple_rep

In [46]:
# Fixed seed handed to MONAI's determinism helper.
set_determinism(seed=2023)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): cuDNN autotuning is enabled right after set_determinism();
# confirm that trading repeatability for speed is intended here.
torch.backends.cudnn.benchmark = True

# Output directory named after the start time of this run.
x = datetime.datetime.now()
root_dir = x.strftime("%m_%d_%Y %H%M%S")
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(root_dir, exist_ok=True)

print(root_dir)

07_15_2024 175000


In [10]:
class ShuffleBlock(nn.Module):
    """Group-shuffle layer for 5-D tensors.

    Dimension 2 of the input is split into ``groups`` chunks, the chunk axis
    is swapped with dimension 1, and the result is flattened back to the
    original shape.

    NOTE(review): for a (batch, channel, D, H, W) input this regroups
    dimension 2 rather than the channel axis - confirm that is intended.
    """

    def __init__(self, groups=4):
        super().__init__()
        self.groups = groups

    def forward(self, x):
        """Return ``x`` regrouped; the output has the same shape as ``x``.

        The size of dimension 2 must be divisible by ``self.groups``.
        """
        dim0, dim1, dim2, dim3, dim4 = x.shape
        n_groups = self.groups
        grouped = x.view(dim0, dim1, n_groups, dim2 // n_groups, dim3, dim4)
        # Swapping axes 1 and 2 is the same as permute(0, 2, 1, 3, 4, 5).
        swapped = grouped.transpose(1, 2)
        return swapped.reshape(dim0, dim1, dim2, dim3, dim4)



In [13]:
class DoubleConv(nn.Module):
    """Conv3d -> BatchNorm3d -> channel squeeze-and-excitation -> LeakyReLU.

    Args:
        in_ch: number of input channels of the convolution.
        out_ch: number of output channels (also used by the norm and SE layers).
        kernel_size, stride, padding, dilation, groups, bias: forwarded
            unchanged to ``nn.Conv3d``.
    """

    def __init__(
        self,
        in_ch,
        out_ch,
        kernel_size=3,
        stride=1,
        padding=1,
        dilation=1,
        groups=1,
        bias=True,
    ):
        super().__init__()

        # Creation order is kept stable so parameter registration order does
        # not change.
        self.se_layer = ChannelSELayer(
            spatial_dims=3,
            in_channels=out_ch,
            r=2,
            acti_type_1="relu",
            acti_type_2="sigmoid",
        )
        self.relu = nn.LeakyReLU(inplace=True)
        self.cov = nn.Conv3d(
            in_channels=in_ch,
            out_channels=out_ch,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        self.bn = nn.BatchNorm3d(num_features=out_ch)

    def forward(self, input):
        """Apply convolution, batch norm, SE re-weighting and the activation."""
        features = self.cov(input)
        features = self.bn(features)
        features = self.se_layer(features)
        return self.relu(features)

In [40]:

class E_CATBraTS(nn.Module):
    """Segmentation network built from a Swin Transformer encoder and a UNETR-style decoder.

    The raw input and each of the five Swin hidden states pass through a
    ``DoubleConv`` block (hidden states go through a ``ShuffleBlock`` first).
    Five ``UnetrUpBlock`` stages then merge them from the deepest level
    upwards and a ``UnetOutBlock`` produces the logits.
    """

    def __init__(
        self,
        img_size: Union[Sequence[int], int],
        in_channels: int,
        out_channels: int,
        depths: Sequence[int] = (2, 2, 2, 2),
        num_heads: Sequence[int] = (3, 6, 12, 24),
        feature_size: int = 24,
        norm_name: Union[Tuple, str] = "instance",
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        dropout_path_rate: float = 0.0,
        normalize: bool = True,
        use_checkpoint: bool = False,
        spatial_dims: int = 3,
        downsample="merging",
        init_filters: int = 8,
        dropout_prob: Optional[float] = None,
        act: Union[Tuple, str] = ("RELU", {"inplace": True}),
        norm: Union[Tuple, str] = ("GROUP", {"num_groups": 8}),
        num_groups: int = 8,
        use_conv_final: bool = True,
        blocks_down: tuple = (1, 2, 2, 4),
        blocks_up: tuple = (1, 1, 1),
        upsample_mode: Union[UpsampleMode, str] = UpsampleMode.NONTRAINABLE,
        num_class = 2,
        with_BN = True,
        channel_width = 4
    ) -> None:
        """Validate the configuration and build every sub-module.

        Args:
            img_size: input size; used only for validation (each entry must be
                divisible by 2, 4, 8, 16 and 32).
            in_channels: channels of the input image.
            out_channels: channels of the output logits.
            depths, num_heads, drop_rate, attn_drop_rate, dropout_path_rate,
                use_checkpoint, spatial_dims: forwarded to ``SwinTransformer``.
            feature_size: Swin embedding size; must be divisible by 12.
            norm_name: normalisation passed to each ``UnetrUpBlock``.
            normalize: passed to the Swin encoder on every forward call.
            act: stored as ``self.act``.
            downsample, init_filters, dropout_prob, norm, num_groups,
                use_conv_final, blocks_down, blocks_up, upsample_mode,
                num_class, with_BN, channel_width: accepted but not used by
                this constructor.

        Raises:
            ValueError: if ``spatial_dims`` is not 2 or 3, ``img_size`` fails
                the divisibility check, a drop rate is outside [0, 1], or
                ``feature_size`` is not divisible by 12.
        """
        super().__init__()

        img_size = ensure_tuple_rep(img_size, spatial_dims)
        patch_size = ensure_tuple_rep(2, spatial_dims)
        window_size = ensure_tuple_rep(7, spatial_dims)

        if spatial_dims not in (2, 3):
            raise ValueError("spatial dimension should be 2 or 3.")

        # Every spatial extent must survive five successive patch-size reductions.
        for extent, patch in zip(img_size, patch_size):
            if any(extent % np.power(patch, level) != 0 for level in range(1, 6)):
                raise ValueError("input image size (img_size) should be divisible by stage-wise image resolution.")

        rate_checks = (
            (drop_rate, "dropout rate should be between 0 and 1."),
            (attn_drop_rate, "attention dropout rate should be between 0 and 1."),
            (dropout_path_rate, "drop path rate should be between 0 and 1."),
        )
        for rate, message in rate_checks:
            if not (0 <= rate <= 1):
                raise ValueError(message)

        if feature_size % 12 != 0:
            raise ValueError("feature_size should be divisible by 12.")

        self.normalize = normalize
        self.act = act
        self.swinViT = SwinTransformer(
            in_chans=in_channels,
            embed_dim=feature_size,
            window_size=window_size,
            patch_size=patch_size,
            depths=depths,
            num_heads=num_heads,
            mlp_ratio=4.0,
            qkv_bias=True,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=dropout_path_rate,
            norm_layer=nn.LayerNorm,
            use_checkpoint=use_checkpoint,
            spatial_dims=spatial_dims,
        )

        def se_conv(n_in, n_out):
            # 3x3x3 DoubleConv used on the input and on every encoder level.
            return DoubleConv(in_ch=n_in, out_ch=n_out, kernel_size=3, padding=1)

        def up_block(n_in, n_out):
            # Residual UNETR up-sampling stage with 2x transposed-conv upsampling.
            return UnetrUpBlock(
                spatial_dims=spatial_dims,
                in_channels=n_in,
                out_channels=n_out,
                kernel_size=3,
                upsample_kernel_size=2,
                norm_name=norm_name,
                res_block=True,
            )

        width = feature_size

        # Skip-path convolutions; creation order is kept stable on purpose.
        self.cov3d_12_en = se_conv(in_channels, width)
        self.cov3d_22_en = se_conv(width, width)
        self.cov3d_32_en = se_conv(2 * width, 2 * width)
        self.cov3d_42_en = se_conv(4 * width, 4 * width)
        self.cov3d_4_en = se_conv(8 * width, 8 * width)
        self.cov3d_52_en = se_conv(16 * width, 16 * width)

        # Decoder stages, deepest first.
        self.decoder5 = up_block(16 * width, 8 * width)
        self.decoder4 = up_block(8 * width, 4 * width)
        self.decoder3 = up_block(4 * width, 2 * width)
        self.decoder2 = up_block(2 * width, width)
        self.decoder1 = up_block(width, width)

        self.out = UnetOutBlock(spatial_dims=spatial_dims, in_channels=width, out_channels=out_channels)

        # One shuffle layer per Swin hidden state.
        self.shuffle1 = ShuffleBlock()
        self.shuffle2 = ShuffleBlock()
        self.shuffle3 = ShuffleBlock()
        self.shuffle4 = ShuffleBlock()
        self.shuffle5 = ShuffleBlock()

    def load_from(self, weights):
        """Copy pretrained Swin encoder parameters out of ``weights["state_dict"]``.

        The patch-embedding projection and, for each of ``layers1`` to
        ``layers4``, the transformer blocks plus the downsample reduction and
        norm parameters are read from keys prefixed with ``module.``.
        """
        with torch.no_grad():
            state = weights["state_dict"]

            patch_proj = self.swinViT.patch_embed.proj
            patch_proj.weight.copy_(state["module.patch_embed.proj.weight"])
            patch_proj.bias.copy_(state["module.patch_embed.proj.bias"])

            for stage in ("layers1", "layers2", "layers3", "layers4"):
                stage_layer = getattr(self.swinViT, stage)[0]
                for bname, block in stage_layer.blocks.named_children():
                    block.load_from(weights, n_block=bname, layer=stage)

                merge = stage_layer.downsample
                prefix = "module." + stage + ".0.downsample."
                merge.reduction.weight.copy_(state[prefix + "reduction.weight"])
                merge.norm.weight.copy_(state[prefix + "norm.weight"])
                merge.norm.bias.copy_(state[prefix + "norm.bias"])

    def forward(self, x_in):
        """Run the encoder, the skip-path convolutions and the decoder; return logits."""
        hidden = self.swinViT(x_in, self.normalize)

        # Skip features: the raw input plus the first four shuffled hidden states.
        skip0 = self.cov3d_12_en(x_in)
        skip1 = self.cov3d_22_en(self.shuffle1(hidden[0]))
        skip2 = self.cov3d_32_en(self.shuffle2(hidden[1]))
        skip3 = self.cov3d_42_en(self.shuffle3(hidden[2]))
        skip4 = self.cov3d_4_en(self.shuffle4(hidden[3]))

        # The deepest hidden state seeds the decoder.
        bottleneck = self.cov3d_52_en(self.shuffle5(hidden[4]))

        up = self.decoder5(bottleneck, skip4)
        up = self.decoder4(up, skip3)
        up = self.decoder3(up, skip2)
        up = self.decoder2(up, skip1)
        up = self.decoder1(up, skip0)
        return self.out(up)

In [24]:
#SWIN TRANSFORMER AVERAGEMETER
class AverageMeter(object):
    """Running weighted average of scalar or array-valued measurements.

    ``val`` holds the latest measurement, ``sum`` the weighted total,
    ``count`` the accumulated weight and ``avg`` the current mean.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Add measurement ``val`` with weight ``n`` and refresh ``avg``.

        ``val`` and ``n`` may be Python scalars or NumPy arrays. Wherever the
        accumulated weight is still zero, ``avg`` falls back to ``sum``.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # np.where evaluates both branches eagerly, so the division runs even
        # where count == 0. Dividing through NumPy with the warnings silenced
        # avoids ZeroDivisionError for Python scalars and RuntimeWarning noise
        # for arrays; those entries are replaced by ``sum`` below anyway.
        with np.errstate(divide="ignore", invalid="ignore"):
            mean = np.true_divide(self.sum, self.count)
        self.avg = np.where(self.count > 0, mean, self.sum)



In [25]:

def LoadDataset(img_path, roi):
    """Build train/validation/test loaders from a directory of patient folders.

    The first 500 shuffled cases are split into 350 training, 100 validation
    and 50 test cases.

    Args:
        img_path: directory holding one sub-directory per patient. Each
            patient directory is expected to contain the T1, T1c, T2 and FLAIR
            volumes plus ``tumor_segmentation.nii.gz``, all named with the
            directory name minus a trailing ``"nifti"`` as prefix.
        roi: sequence whose first three entries give the crop size used by the
            training transforms.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader, test_ds)``.

    NOTE(review): ``batch_size`` is read from the notebook's global scope and
    must be defined before this function is called.
    """
    # Sort first: os.listdir order is arbitrary, so shuffling the raw listing
    # would not give a repeatable split even with a fixed random seed.
    patients = sorted(os.listdir(img_path))
    random.shuffle(patients)

    modality_suffixes = ("T1.nii.gz", "T1c.nii.gz", "T2.nii.gz", "FLAIR.nii.gz")
    images = []
    for patient in patients:
        prefix = patient.removesuffix("nifti")
        patient_dir = os.path.join(img_path, patient)
        images.append(
            {
                "image": [os.path.join(patient_dir, prefix + suffix) for suffix in modality_suffixes],
                "label": [os.path.join(patient_dir, prefix + "tumor_segmentation.nii.gz")],
            }
        )

    # 350 / 100 / 50 split. The original slices [350:451] and [451:500]
    # produced 101 validation and 49 test cases.
    n_train, n_val, n_test = 350, 100, 50
    val_end = n_train + n_val
    train_dt = images[:n_train]
    val_dt = images[n_train:val_end]
    test_dt = images[val_end:val_end + n_test]

    roi_size = [roi[0], roi[1], roi[2]]

    train_transform = transforms.Compose(
        [
            transforms.LoadImaged(keys=["image", "label"]),
            transforms.ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
            transforms.CropForegroundd(
                keys=["image", "label"],
                source_key="image",
                k_divisible=roi_size,
            ),
            transforms.RandSpatialCropd(
                keys=["image", "label"],
                roi_size=roi_size,
                random_size=False,
            ),
            transforms.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
            transforms.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
            transforms.RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
            transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
            transforms.RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
            transforms.RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
        ]
    )

    def _eval_transform():
        # Pipeline without cropping or augmentation, shared by validation and test.
        return transforms.Compose(
            [
                transforms.LoadImaged(keys=["image", "label"]),
                transforms.ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
                transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
            ]
        )

    train_ds = data.Dataset(data=train_dt, transform=train_transform)
    train_loader = data.DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=True,
    )

    val_ds = data.Dataset(data=val_dt, transform=_eval_transform())
    val_loader = data.DataLoader(
        val_ds,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
    )

    test_ds = data.Dataset(data=test_dt, transform=_eval_transform())
    test_loader = data.DataLoader(
        test_ds,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
    )

    return train_loader, val_loader, test_loader, test_ds


In [58]:

# NOTE(review): `roi` and `max_epochs` are read from the notebook's global
# scope and must be defined in an earlier configuration cell.
model = E_CATBraTS(
    img_size=roi,
    in_channels=4,
    out_channels=3,
    feature_size=48,
    drop_rate=0.0,
    attn_drop_rate=0.0,
    dropout_path_rate=0.0,
    use_checkpoint=True,
).to(device)


dice_metric = DiceMetric(include_background=True, reduction="mean")
loss_function = DiceLoss(sigmoid=True)


# Sigmoid followed by a 0.5 threshold turns logits into binary masks.
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

window_inferer = SlidingWindowInferer(roi_size = roi, overlap=0.5)

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=1e-5)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)

# The training and validation cells refer to the objects above under different
# names; define those names here so the notebook runs top to bottom.
loss_func = loss_function
scheduler = lr_scheduler
post_sigmoid = Activations(sigmoid=True)
post_pred = AsDiscrete(threshold=0.5)
# val_epoch calls model_inferer(data) with the input only, so bind the network.
model_inferer = partial(window_inferer, network=model)
# val_epoch unpacks aggregate() into (acc, not_nans) and indexes per-class
# values, which needs a per-class reduction with get_not_nans enabled.
acc_func = DiceMetric(include_background=True, reduction="mean_batch", get_not_nans=True)

In [None]:
def train_epoch(model, loader, optimizer, epoch, loss_func):
    """Train ``model`` for one pass over ``loader`` and return the mean loss.

    Args:
        model: network to optimise; switched to training mode.
        loader: iterable of batches with ``"image"`` and ``"label"`` entries.
        optimizer: optimiser stepped once per batch.
        epoch: current epoch index, used only for logging.
        loss_func: callable ``loss_func(logits, target)`` returning the loss.

    Returns:
        The running average loss held by an ``AverageMeter``.

    NOTE(review): ``device`` and ``max_epochs`` are read from the notebook's
    global scope.
    """
    model.train()
    start_time = time.time()
    run_loss = AverageMeter()
    for idx, batch_data in enumerate(loader):
        inputs, target = batch_data["image"].to(device), batch_data["label"].to(device)
        # Clear gradients left over from the previous iteration; without this
        # they accumulate across batches.
        optimizer.zero_grad()
        logits = model(inputs)
        loss = loss_func(logits, target)
        loss.backward()
        optimizer.step()
        # Weight by the real batch size so a smaller last batch is averaged correctly.
        run_loss.update(loss.item(), n=inputs.shape[0])
        print(
            "Epoch {}/{} {}/{}".format(epoch, max_epochs, idx, len(loader)),
            "loss: {:.4f}".format(run_loss.avg),
            "time {:.2f}s".format(time.time() - start_time),
        )
        start_time = time.time()
    return run_loss.avg


def val_epoch(
    model,
    loader,
    epoch,
    acc_func,
    model_inferer=None,
    post_sigmoid=None,
    post_pred=None,
):
    """Evaluate ``model`` on ``loader`` and return the running per-class accuracy.

    Args:
        model: network under evaluation; switched to eval mode.
        loader: iterable of batches with ``"image"`` and ``"label"`` entries.
        epoch: current epoch index, used only for logging.
        acc_func: metric object; it is reset, called with
            ``y_pred``/``y`` lists, and ``aggregate()`` must return
            ``(acc, not_nans)`` with at least three entries each.
        model_inferer: callable mapping an input batch to logits. When
            ``None`` the model itself is called.
        post_sigmoid: optional transform applied to each prediction first.
        post_pred: optional transform applied after ``post_sigmoid``.

    Returns:
        ``run_acc.avg`` - the averaged accuracy values, indexed as TC, WT, ET
        in the log output.

    NOTE(review): ``device`` and ``max_epochs`` are read from the notebook's
    global scope.
    """
    model.eval()
    start_time = time.time()
    run_acc = AverageMeter()

    # The keyword defaults are None; fall back to sensible behaviour instead of
    # failing with "'NoneType' object is not callable".
    infer = model if model_inferer is None else model_inferer

    def to_prediction(pred):
        if post_sigmoid is not None:
            pred = post_sigmoid(pred)
        if post_pred is not None:
            pred = post_pred(pred)
        return pred

    with torch.no_grad():
        for idx, batch_data in enumerate(loader):
            data, target = batch_data["image"].to(device), batch_data["label"].to(
                device
            )
            logits = infer(data)
            val_labels_list = decollate_batch(target)
            val_outputs_list = decollate_batch(logits)
            val_output_convert = [
                to_prediction(val_pred_tensor)
                for val_pred_tensor in val_outputs_list
            ]
            # The metric is reset per batch; accumulation across batches is
            # done by run_acc, weighted by the not-NaN counts.
            acc_func.reset()
            acc_func(y_pred=val_output_convert, y=val_labels_list)
            acc, not_nans = acc_func.aggregate()
            run_acc.update(acc.cpu().numpy(), n=not_nans.cpu().numpy())
            Dice_TC = run_acc.avg[0]
            Dice_WT = run_acc.avg[1]
            Dice_ET = run_acc.avg[2]
            print(
                "Val {}/{} {}/{}".format(epoch, max_epochs, idx, len(loader)),
                ", Dice_TC:",
                Dice_TC,
                ", Dice_WT:",
                Dice_WT,
                ", Dice_ET:",
                Dice_ET,
                ", time {:.2f}s".format(time.time() - start_time),
            )
            start_time = time.time()

    return run_acc.avg

In [None]:
# Main training loop.
# NOTE(review): max_epochs, val_every, train_loader, val_loader, loss_func,
# acc_func, model_inferer, post_sigmoid, post_pred, scheduler and
# save_checkpoint are read from the notebook's global scope and must be
# defined by earlier cells.
start_epoch = 0

val_acc_max = 0.0
# Per-validation history, one entry per validated epoch.
Dices_TC = []
Dices_WT = []
Dices_ET = []
Dices_avg = []
loss_epochs = []
trains_epoch = []
for epoch in range(start_epoch, max_epochs):
    print(time.ctime(), "Epoch:", epoch)
    epoch_time = time.time()
    train_loss = train_epoch(
        model,
        train_loader,
        optimizer,
        epoch=epoch,
        loss_func=loss_func,
    )
    print(
        "Final training  {}/{}".format(epoch, max_epochs - 1),
        "loss: {:.4f}".format(train_loss),
        "time {:.2f}s".format(time.time() - epoch_time),
    )

    # Validate on the first epoch and then every `val_every` epochs.
    if (epoch + 1) % val_every == 0 or epoch == 0:
        loss_epochs.append(train_loss)
        trains_epoch.append(int(epoch))
        epoch_time = time.time()
        val_acc = val_epoch(
            model,
            val_loader,
            epoch=epoch,
            acc_func=acc_func,
            model_inferer=model_inferer,
            post_sigmoid=post_sigmoid,
            post_pred=post_pred,
        )
        Dice_TC = val_acc[0]
        Dice_WT = val_acc[1]
        Dice_ET = val_acc[2]
        val_avg_acc = np.mean(val_acc)
        print(
            "Final validation stats {}/{}".format(epoch, max_epochs - 1),
            ", Dice_TC:",
            Dice_TC,
            ", Dice_WT:",
            Dice_WT,
            ", Dice_ET:",
            Dice_ET,
            ", Dice_Avg:",
            val_avg_acc,
            ", time {:.2f}s".format(time.time() - epoch_time),
        )
        Dices_TC.append(Dice_TC)
        Dices_WT.append(Dice_WT)
        Dices_ET.append(Dice_ET)
        Dices_avg.append(val_avg_acc)
        # Keep only the checkpoint with the best mean Dice so far.
        if val_avg_acc > val_acc_max:
            print("new best ({:.6f} --> {:.6f}). ".format(val_acc_max, val_avg_acc))
            val_acc_max = val_avg_acc
            save_checkpoint(
                model,
                epoch,
                best_acc=val_acc_max,
            )
    # Step once per epoch. Previously this sat inside the validation branch,
    # so the learning-rate schedule only advanced on validation epochs.
    scheduler.step()
print("Training Finished !, Best Accuracy: ", val_acc_max)
