## **Train.py**

## **BraTS.py**

In [1]:
modalities = ('flair', 't1ce', 't1', 't2')

In [2]:
import os
import torch
from torch.utils.data import Dataset
import random
import numpy as np
from torchvision.transforms import transforms
import pickle
from scipy import ndimage
import pickle
import os
import numpy as np
import nibabel as nib

def pkload(fname):
    with open(fname, 'rb') as f:
        return pickle.load(f)


class MaxMinNormalization(object):
    def __call__(self, sample):
        image = sample['image']
        label = sample['label']
        Max = np.max(image)
        Min = np.min(image)
        image = (image - Min) / (Max - Min)

        return {'image': image, 'label': label}


class Random_Flip(object):
    def __call__(self, sample):
        image = sample['image']
        label = sample['label']
        if random.random() < 0.5:
            image = np.flip(image, 0)
            label = np.flip(label, 0)
        if random.random() < 0.5:
            image = np.flip(image, 1)
            label = np.flip(label, 1)
        if random.random() < 0.5:
            image = np.flip(image, 2)
            label = np.flip(label, 2)

        return {'image': image, 'label': label}


class Random_Crop(object):
    def __call__(self, sample):
        image = sample['image']
        label = sample['label']
        H = random.randint(0, 240 - 128)
        W = random.randint(0, 240 - 128)
        D = random.randint(0, 160 - 128)

        image = image[H: H + 128, W: W + 128, D: D + 128, ...]
        label = label[..., H: H + 128, W: W + 128, D: D + 128]

        return {'image': image, 'label': label}


class Random_intencity_shift(object):
    def __call__(self, sample, factor=0.1):
        image = sample['image']
        label = sample['label']

        scale_factor = np.random.uniform(1.0-factor, 1.0+factor, size=[1, image.shape[1], 1, image.shape[-1]])
        shift_factor = np.random.uniform(-factor, factor, size=[1, image.shape[1], 1, image.shape[-1]])

        image = image*scale_factor+shift_factor

        return {'image': image, 'label': label}


class Random_rotate(object):
    def __call__(self, sample):
        image = sample['image']
        label = sample['label']

        angle = round(np.random.uniform(-10, 10), 2)
        image = ndimage.rotate(image, angle, axes=(0, 1), reshape=False)
        label = ndimage.rotate(label, angle, axes=(0, 1), reshape=False)

        return {'image': image, 'label': label}


class Pad(object):
    def __call__(self, sample):
        image = sample['image']
        label = sample['label']

        image = np.pad(image, ((0, 0), (0, 0), (0, 5), (0, 0)), mode='constant')
        label = np.pad(label, ((0, 0), (0, 0), (0, 5)), mode='constant')
        return {'image': image, 'label': label}
    #(240,240,155)>(240,240,160)


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""
    def __call__(self, sample):
        image = sample['image']
        image = np.ascontiguousarray(image.transpose(3, 0, 1, 2))
        label = sample['label']
        label = np.ascontiguousarray(label)

        image = torch.from_numpy(image).float()
        label = torch.from_numpy(label).long()

        return {'image': image, 'label': label}


def transform(sample):
    trans = transforms.Compose([
        Pad(),
        # Random_rotate(),  # time-consuming
        Random_Crop(),
        Random_Flip(),
        Random_intencity_shift(),
        ToTensor()
    ])

    return trans(sample)


def transform_valid(sample):
    trans = transforms.Compose([
        Pad(),
        # MaxMinNormalization(),
        ToTensor()
    ])

    return trans(sample)


class BraTS(Dataset):
    def __init__(self, list_file, root='', mode='train'):
        self.lines = []
        paths, names = [], []
        with open(list_file) as f:
            for line in f:
                line = line.strip()
                name = line.split('/')[-1]
                names.append(name)
                path = os.path.join(root, line, name + '_')
                paths.append(path)
                self.lines.append(line)
        self.mode = mode
        self.names = names
        self.paths = paths

    def __getitem__(self, item):
        path = self.paths[item]
        if self.mode == 'train':
#             image, label = pkload(path + 'data_f32b0.pkl')
            label = np.array(nib_load(path + 'seg.nii'), dtype='uint8', order='C')
            images = np.stack([np.array(nib_load(path + modal + '.nii'), dtype='float32', order='C') for modal in modalities], -1)  # [240,240,155]

            output = path + 'data_f32b0.pkl'
            mask = images.sum(-1) > 0
            for k in range(4):
                x = images[..., k]  #
                y = x[mask]

                # 0.8885
                x[mask] -= y.mean()
                x[mask] /= y.std()

                images[..., k] = x

            sample = {'image': images, 'label': label}
            sample = transform(sample)
            return sample['image'], sample['label']
        elif self.mode == 'valid':
            image, label = pkload(path + 'data_f32b0.pkl')
            sample = {'image': image, 'label': label}
            sample = transform_valid(sample)
            return sample['image'], sample['label']
        else:
            image = pkload(path + 'data_f32b0.pkl')
            image = np.pad(image, ((0, 0), (0, 0), (0, 5), (0, 0)), mode='constant')
            image = np.ascontiguousarray(image.transpose(3, 0, 1, 2))
            image = torch.from_numpy(image).float()
            return image

    def __len__(self):
        return len(self.names)

    def collate(self, batch):
        return [torch.cat(v) for v in zip(*batch)]



## **IntmdSequential.py**

In [3]:
import torch.nn as nn


class IntermediateSequential(nn.Sequential):
    def __init__(self, *args, return_intermediate=True):
        super().__init__(*args)
        self.return_intermediate = return_intermediate

    def forward(self, input):
        if not self.return_intermediate:
            return super().forward(input)

        intermediate_outputs = {}
        output = input
        for name, module in self.named_children():
            output = intermediate_outputs[name] = module(output)

        return output, intermediate_outputs
        


## **Transformer.py**

In [4]:
import torch.nn as nn


class SelfAttention(nn.Module):
    def __init__(
        self, dim, heads=8, qkv_bias=False, qk_scale=None, dropout_rate=0.0
    ):
        super().__init__()
        self.num_heads = heads
        head_dim = dim // heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(dropout_rate)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(dropout_rate)

    def forward(self, x):
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = (
            qkv[0],
            qkv[1],
            qkv[2],
        )  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x):
        return self.fn(self.norm(x))


class PreNormDrop(nn.Module):
    def __init__(self, dim, dropout_rate, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.fn = fn

    def forward(self, x):
        return self.dropout(self.fn(self.norm(x)))


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout_rate):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(p=dropout_rate),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(p=dropout_rate),
        )

    def forward(self, x):
        return self.net(x)


class TransformerModel(nn.Module):
    def __init__(
        self,
        dim,
        depth,
        heads,
        mlp_dim,
        dropout_rate=0.1,
        attn_dropout_rate=0.1,
    ):
        super().__init__()
        layers = []
        for _ in range(depth):
            layers.extend(
                [
                    Residual(
                        PreNormDrop(
                            dim,
                            dropout_rate,
                            SelfAttention(dim, heads=heads, dropout_rate=attn_dropout_rate),
                        )
                    ),
                    Residual(
                        PreNorm(dim, FeedForward(dim, mlp_dim, dropout_rate))
                    ),
                ]
            )
            # dim = dim / 2
        self.net = IntermediateSequential(*layers)


    def forward(self, x):
        return self.net(x)

## **PositionalEncoding.py**

In [5]:
import torch
import torch.nn as nn

class FixedPositionalEncoding(nn.Module):
    def __init__(self, embedding_dim, max_length=512):
        super(FixedPositionalEncoding, self).__init__()

        pe = torch.zeros(max_length, embedding_dim)
        position = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, embedding_dim, 2).float()
            * (-torch.log(torch.tensor(10000.0)) / embedding_dim)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[: x.size(0), :]
        return x


class LearnedPositionalEncoding(nn.Module):
    def __init__(self, max_position_embeddings, embedding_dim, seq_length):
        super(LearnedPositionalEncoding, self).__init__()

        self.position_embeddings = nn.Parameter(torch.zeros(1, 4096, 512)) #8x

    def forward(self, x, position_ids=None):

        position_embeddings = self.position_embeddings
        return x + position_embeddings

## **Unet_skipconnection.py**

In [6]:
import torch.nn as nn
import torch.nn.functional as F
import torch

# adapt from https://github.com/MIC-DKFZ/BraTS2017


def normalization(planes, norm='gn'):
    if norm == 'bn':
        m = nn.BatchNorm3d(planes)
    elif norm == 'gn':
        m = nn.GroupNorm(8, planes)
    elif norm == 'in':
        m = nn.InstanceNorm3d(planes)
    else:
        raise ValueError('normalization type {} is not supported'.format(norm))
    return m



class InitConv(nn.Module):
    def __init__(self, in_channels=4, out_channels=16, dropout=0.2):
        super(InitConv, self).__init__()

        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
        self.dropout = dropout

    def forward(self, x):
        y = self.conv(x)
        y = F.dropout3d(y, self.dropout)

        return y


class EnBlock(nn.Module):
    def __init__(self, in_channels, norm='gn'):
        super(EnBlock, self).__init__()

        self.bn1 = normalization(in_channels, norm=norm)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1)

        self.bn2 = normalization(in_channels, norm=norm)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, x):
        x1 = self.bn1(x)
        x1 = self.relu1(x1)
        x1 = self.conv1(x1)
        y = self.bn2(x1)
        y = self.relu2(y)
        y = self.conv2(y)
        y = y + x

        return y


class EnDown(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(EnDown, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        y = self.conv(x)

        return y



class Unet(nn.Module):
    def __init__(self, in_channels=4, base_channels=16, num_classes=4):
        super(Unet, self).__init__()

        self.InitConv = InitConv(in_channels=in_channels, out_channels=base_channels, dropout=0.2)
        self.EnBlock1 = EnBlock(in_channels=base_channels)
        self.EnDown1 = EnDown(in_channels=base_channels, out_channels=base_channels*2)

        self.EnBlock2_1 = EnBlock(in_channels=base_channels*2)
        self.EnBlock2_2 = EnBlock(in_channels=base_channels*2)
        self.EnDown2 = EnDown(in_channels=base_channels*2, out_channels=base_channels*4)

        self.EnBlock3_1 = EnBlock(in_channels=base_channels * 4)
        self.EnBlock3_2 = EnBlock(in_channels=base_channels * 4)
        self.EnDown3 = EnDown(in_channels=base_channels*4, out_channels=base_channels*8)

        self.EnBlock4_1 = EnBlock(in_channels=base_channels * 8)
        self.EnBlock4_2 = EnBlock(in_channels=base_channels * 8)
        self.EnBlock4_3 = EnBlock(in_channels=base_channels * 8)
        self.EnBlock4_4 = EnBlock(in_channels=base_channels * 8)

    def forward(self, x):
        x = self.InitConv(x)       # (1, 16, 128, 128, 128)

        x1_1 = self.EnBlock1(x)
        x1_2 = self.EnDown1(x1_1)  # (1, 32, 64, 64, 64)

        x2_1 = self.EnBlock2_1(x1_2)
        x2_1 = self.EnBlock2_2(x2_1)
        x2_2 = self.EnDown2(x2_1)  # (1, 64, 32, 32, 32)

        x3_1 = self.EnBlock3_1(x2_2)
        x3_1 = self.EnBlock3_2(x3_1)
        x3_2 = self.EnDown3(x3_1)  # (1, 128, 16, 16, 16)

        x4_1 = self.EnBlock4_1(x3_2)
        x4_2 = self.EnBlock4_2(x4_1)
        x4_3 = self.EnBlock4_3(x4_2)
        output = self.EnBlock4_4(x4_3)  # (1, 128, 16, 16, 16)

        return x1_1,x2_1,x3_1,output


# if __name__ == '__main__':
#     with torch.no_grad():
#         import os
#         os.environ['CUDA_VISIBLE_DEVICES'] = '0'
#         cuda0 = torch.device('cuda:0')
#         x = torch.rand((1, 4, 128, 128, 128), device=cuda0)
#         # model = Unet1(in_channels=4, base_channels=16, num_classes=4)
#         model = Unet(in_channels=4, base_channels=16, num_classes=4)
#         model.cuda()
#         output = model(x)
#         print('output:', output.shape)

## **TransBTS_downsample8x_skipconnection.py**

In [7]:
import torch
import torch.nn as nn


class TransformerBTS(nn.Module):
    def __init__(
        self,
        img_dim,
        patch_dim,
        num_channels,
        embedding_dim,
        num_heads,
        num_layers,
        hidden_dim,
        dropout_rate=0.0,
        attn_dropout_rate=0.0,
        conv_patch_representation=True,
        positional_encoding_type="learned",
    ):
        super(TransformerBTS, self).__init__()

        assert embedding_dim % num_heads == 0
        assert img_dim % patch_dim == 0

        self.img_dim = img_dim
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.patch_dim = patch_dim
        self.num_channels = num_channels
        self.dropout_rate = dropout_rate
        self.attn_dropout_rate = attn_dropout_rate
        self.conv_patch_representation = conv_patch_representation

        self.num_patches = int((img_dim // patch_dim) ** 3)
        self.seq_length = self.num_patches
        self.flatten_dim = 128 * num_channels

        self.linear_encoding = nn.Linear(self.flatten_dim, self.embedding_dim)
        if positional_encoding_type == "learned":
            self.position_encoding = LearnedPositionalEncoding(
                self.seq_length, self.embedding_dim, self.seq_length
            )
        elif positional_encoding_type == "fixed":
            self.position_encoding = FixedPositionalEncoding(
                self.embedding_dim,
            )

        self.pe_dropout = nn.Dropout(p=self.dropout_rate)

        self.transformer = TransformerModel(
            embedding_dim,
            num_layers,
            num_heads,
            hidden_dim,

            self.dropout_rate,
            self.attn_dropout_rate,
        )
        self.pre_head_ln = nn.LayerNorm(embedding_dim)

        if self.conv_patch_representation:

            self.conv_x = nn.Conv3d(
                128,
                self.embedding_dim,
                kernel_size=3,
                stride=1,
                padding=1
            )

        self.Unet = Unet(in_channels=4, base_channels=16, num_classes=4)
        self.bn = nn.BatchNorm3d(128)
        self.relu = nn.ReLU(inplace=True)


    def encode(self, x):
        if self.conv_patch_representation:
            # combine embedding with conv patch distribution
            x1_1, x2_1, x3_1, x = self.Unet(x)
            x = self.bn(x)
            x = self.relu(x)
            x = self.conv_x(x)
            x = x.permute(0, 2, 3, 4, 1).contiguous()
            x = x.view(x.size(0), -1, self.embedding_dim)

        else:
            x = self.Unet(x)
            x = self.bn(x)
            x = self.relu(x)
            x = (
                x.unfold(2, 2, 2)
                .unfold(3, 2, 2)
                .unfold(4, 2, 2)
                .contiguous()
            )
            x = x.view(x.size(0), x.size(1), -1, 8)
            x = x.permute(0, 2, 3, 1).contiguous()
            x = x.view(x.size(0), -1, self.flatten_dim)
            x = self.linear_encoding(x)

        x = self.position_encoding(x)
        x = self.pe_dropout(x)

        # apply transformer
        x, intmd_x = self.transformer(x)
        x = self.pre_head_ln(x)

        return x1_1, x2_1, x3_1, x, intmd_x

    def decode(self, x):
        raise NotImplementedError("Should be implemented in child class!!")

    def forward(self, x, auxillary_output_layers=[1, 2, 3, 4]):

        x1_1, x2_1, x3_1, encoder_output, intmd_encoder_outputs = self.encode(x)

        decoder_output = self.decode(
            x1_1, x2_1, x3_1, encoder_output, intmd_encoder_outputs, auxillary_output_layers
        )

        if auxillary_output_layers is not None:
            auxillary_outputs = {}
            for i in auxillary_output_layers:
                val = str(2 * i - 1)
                _key = 'Z' + str(i)
                auxillary_outputs[_key] = intmd_encoder_outputs[val]

            return decoder_output

        return decoder_output

    def _get_padding(self, padding_type, kernel_size):
        assert padding_type in ['SAME', 'VALID']
        if padding_type == 'SAME':
            _list = [(k - 1) // 2 for k in kernel_size]
            return tuple(_list)
        return tuple(0 for _ in kernel_size)

    def _reshape_output(self, x):
        x = x.view(
            x.size(0),
            int(self.img_dim / self.patch_dim),
            int(self.img_dim / self.patch_dim),
            int(self.img_dim / self.patch_dim),
            self.embedding_dim,
        )
        x = x.permute(0, 4, 1, 2, 3).contiguous()

        return x


class BTS(TransformerBTS):
    def __init__(
        self,
        img_dim,
        patch_dim,
        num_channels,
        num_classes,
        embedding_dim,
        num_heads,
        num_layers,
        hidden_dim,
        dropout_rate=0.0,
        attn_dropout_rate=0.0,
        conv_patch_representation=True,
        positional_encoding_type="learned",
    ):
        super(BTS, self).__init__(
            img_dim=img_dim,
            patch_dim=patch_dim,
            num_channels=num_channels,
            embedding_dim=embedding_dim,
            num_heads=num_heads,
            num_layers=num_layers,
            hidden_dim=hidden_dim,
            dropout_rate=dropout_rate,
            attn_dropout_rate=attn_dropout_rate,
            conv_patch_representation=conv_patch_representation,
            positional_encoding_type=positional_encoding_type,
        )

        self.num_classes = num_classes

        self.Softmax = nn.Softmax(dim=1)

        self.Enblock8_1 = EnBlock1(in_channels=self.embedding_dim)
        self.Enblock8_2 = EnBlock2(in_channels=self.embedding_dim // 4)

        self.DeUp4 = DeUp_Cat(in_channels=self.embedding_dim//4, out_channels=self.embedding_dim//8)
        self.DeBlock4 = DeBlock(in_channels=self.embedding_dim//8)

        self.DeUp3 = DeUp_Cat(in_channels=self.embedding_dim//8, out_channels=self.embedding_dim//16)
        self.DeBlock3 = DeBlock(in_channels=self.embedding_dim//16)

        self.DeUp2 = DeUp_Cat(in_channels=self.embedding_dim//16, out_channels=self.embedding_dim//32)
        self.DeBlock2 = DeBlock(in_channels=self.embedding_dim//32)

        self.endconv = nn.Conv3d(self.embedding_dim // 32, 4, kernel_size=1)


    def decode(self, x1_1, x2_1, x3_1, x, intmd_x, intmd_layers=[1, 2, 3, 4]):

        assert intmd_layers is not None, "pass the intermediate layers for MLA"
        encoder_outputs = {}
        all_keys = []
        for i in intmd_layers:
            val = str(2 * i - 1)
            _key = 'Z' + str(i)
            all_keys.append(_key)
            encoder_outputs[_key] = intmd_x[val]
        all_keys.reverse()

        x8 = encoder_outputs[all_keys[0]]
        x8 = self._reshape_output(x8)
        x8 = self.Enblock8_1(x8)
        x8 = self.Enblock8_2(x8)

        y4 = self.DeUp4(x8, x3_1)  # (1, 64, 32, 32, 32)
        y4 = self.DeBlock4(y4)

        y3 = self.DeUp3(y4, x2_1)  # (1, 32, 64, 64, 64)
        y3 = self.DeBlock3(y3)

        y2 = self.DeUp2(y3, x1_1)  # (1, 16, 128, 128, 128)
        y2 = self.DeBlock2(y2)

        y = self.endconv(y2)      # (1, 4, 128, 128, 128)
        y = self.Softmax(y)
        return y

class EnBlock1(nn.Module):
    def __init__(self, in_channels):
        super(EnBlock1, self).__init__()

        self.bn1 = nn.BatchNorm3d(512 // 4)
        self.relu1 = nn.ReLU(inplace=True)
        self.bn2 = nn.BatchNorm3d(512 // 4)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv3d(in_channels, in_channels // 4, kernel_size=3, padding=1)
        self.conv2 = nn.Conv3d(in_channels // 4, in_channels // 4, kernel_size=3, padding=1)

    def forward(self, x):
        x1 = self.conv1(x)
        x1 = self.bn1(x1)
        x1 = self.relu1(x1)
        x1 = self.conv2(x1)
        x1 = self.bn2(x1)
        x1 = self.relu2(x1)

        return x1


class EnBlock2(nn.Module):
    def __init__(self, in_channels):
        super(EnBlock2, self).__init__()

        self.conv1 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm3d(512 // 4)
        self.relu1 = nn.ReLU(inplace=True)
        self.bn2 = nn.BatchNorm3d(512 // 4)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, x):
        x1 = self.conv1(x)
        x1 = self.bn1(x1)
        x1 = self.relu1(x1)
        x1 = self.conv2(x1)
        x1 = self.bn2(x1)
        x1 = self.relu2(x1)
        x1 = x1 + x

        return x1


class DeUp_Cat(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DeUp_Cat, self).__init__()
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=1)
        self.conv2 = nn.ConvTranspose3d(out_channels, out_channels, kernel_size=2, stride=2)
        self.conv3 = nn.Conv3d(out_channels*2, out_channels, kernel_size=1)

    def forward(self, x, prev):
        x1 = self.conv1(x)
        y = self.conv2(x1)
        # y = y + prev
        y = torch.cat((prev, y), dim=1)
        y = self.conv3(y)
        return y

class DeBlock(nn.Module):
    def __init__(self, in_channels):
        super(DeBlock, self).__init__()

        self.bn1 = nn.BatchNorm3d(in_channels)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm3d(in_channels)
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        x1 = self.conv1(x)
        x1 = self.bn1(x1)
        x1 = self.relu1(x1)
        x1 = self.conv2(x1)
        x1 = self.bn2(x1)
        x1 = self.relu2(x1)
        x1 = x1 + x

        return x1




def TransBTS(dataset='brats', _conv_repr=True, _pe_type="learned"):

    if dataset.lower() == 'brats':
        img_dim = 128
        num_classes = 4

    num_channels = 4
    patch_dim = 8
    aux_layers = [1, 2, 3, 4]
    model = BTS(
        img_dim,
        patch_dim,
        num_channels,
        num_classes,
        embedding_dim=512,
        num_heads=8,
        num_layers=4,
        hidden_dim=4096,
        dropout_rate=0.1,
        attn_dropout_rate=0.1,
        conv_patch_representation=_conv_repr,
        positional_encoding_type=_pe_type,
    )

    return aux_layers, model


# if __name__ == '__main__':
#     with torch.no_grad():
#         import os
#         os.environ['CUDA_VISIBLE_DEVICES'] = '0'
#         cuda0 = torch.device('cuda:0')
#         x = torch.rand((1, 4, 128, 128, 128), device=cuda0)
#         _, model = TransBTS(dataset='brats', _conv_repr=True, _pe_type="learned")
#         model.cuda()
#         y = model(x)
#         print(y.shape)

## **criterions.py**

In [8]:
import torch
import logging
import torch.nn.functional as F
from torch.autograd import Variable

def expand_target(x, n_class,mode='softmax'):
    """
        Converts NxDxHxW label image to NxCxDxHxW, where each label is stored in a separate channel
        :param input: 4D input image (NxDxHxW)
        :param C: number of channels/labels
        :return: 5D output image (NxCxDxHxW)
        """
    assert x.dim() == 4
    shape = list(x.size())
    shape.insert(1, n_class)
    shape = tuple(shape)
    xx = torch.zeros(shape)
    if mode.lower() == 'softmax':
        xx[:, 1, :, :, :] = (x == 1)
        xx[:, 2, :, :, :] = (x == 2)
        xx[:, 3, :, :, :] = (x == 3)
    if mode.lower() == 'sigmoid':
        xx[:, 0, :, :, :] = (x == 1)
        xx[:, 1, :, :, :] = (x == 2)
        xx[:, 2, :, :, :] = (x == 3)
    return xx.to(x.device)

def flatten(tensor):
    """Flattens a given tensor such that the channel axis is first.
    The shapes are transformed as follows:
       (N, C, D, H, W) -> (C, N * D * H * W)
    """
    C = tensor.size(1)
    # new axis order
    axis_order = (1, 0) + tuple(range(2, tensor.dim()))
    # Transpose: (N, C, D, H, W) -> (C, N, D, H, W)
    transposed = tensor.permute(axis_order)
    # Flatten: (C, N, D, H, W) -> (C, N * D * H * W)
    return transposed.reshape(C, -1)

def Dice(output, target, eps=1e-5):
    target = target.float()
    num = 2 * (output * target).sum()
    den = output.sum() + target.sum() + eps
    return 1.0 - num/den


def softmax_dice(output, target):
    '''
    The dice loss for using softmax activation function
    :param output: (b, num_class, d, h, w)
    :param target: (b, d, h, w)
    :return: softmax dice loss
    '''
    loss1 = Dice(output[:, 1, ...], (target == 1).float())
    loss2 = Dice(output[:, 2, ...], (target == 2).float())
    loss3 = Dice(output[:, 3, ...], (target == 4).float())

    return loss1 + loss2 + loss3, 1-loss1.data, 1-loss2.data, 1-loss3.data


def softmax_dice2(output, target):
    '''
    The dice loss for using softmax activation function
    :param output: (b, num_class, d, h, w)
    :param target: (b, d, h, w)
    :return: softmax dice loss
    '''
    loss0 = Dice(output[:, 0, ...], (target == 0).float())
    loss1 = Dice(output[:, 1, ...], (target == 1).float())
    loss2 = Dice(output[:, 2, ...], (target == 2).float())
    loss3 = Dice(output[:, 3, ...], (target == 4).float())

    return loss1 + loss2 + loss3 + loss0, 1-loss1.data, 1-loss2.data, 1-loss3.data


def sigmoid_dice(output, target):
    '''
    The dice loss for using sigmoid activation function
    :param output: (b, num_class-1, d, h, w)
    :param target: (b, d, h, w)
    :return:
    '''
    loss1 = Dice(output[:, 0, ...], (target == 1).float())
    loss2 = Dice(output[:, 1, ...], (target == 2).float())
    loss3 = Dice(output[:, 2, ...], (target == 4).float())

    return loss1 + loss2 + loss3, 1-loss1.data, 1-loss2.data, 1-loss3.data


def Generalized_dice(output, target, eps=1e-5, weight_type='square'):
    if target.dim() == 4:  #(b, h, w, d)
        target[target == 4] = 3  #transfer label 4 to 3
        target = expand_target(target, n_class=output.size()[1])  #extend target from (b, h, w, d) to (b, c, h, w, d)

    output = flatten(output)[1:, ...]  # transpose [N,4，H,W,D] -> [4，N,H,W,D] -> [3, N*H*W*D] voxels
    target = flatten(target)[1:, ...]  # [class, N*H*W*D]

    target_sum = target.sum(-1)  # sub_class_voxels [3,1] -> 3个voxels
    if weight_type == 'square':
        class_weights = 1. / (target_sum * target_sum + eps)
    elif weight_type == 'identity':
        class_weights = 1. / (target_sum + eps)
    elif weight_type == 'sqrt':
        class_weights = 1. / (torch.sqrt(target_sum) + eps)
    else:
        raise ValueError('Check out the weight_type :', weight_type)

    # print(class_weights)
    intersect = (output * target).sum(-1)
    intersect_sum = (intersect * class_weights).sum()
    denominator = (output + target).sum(-1)
    denominator_sum = (denominator * class_weights).sum() + eps

    loss1 = 2*intersect[0] / (denominator[0] + eps)
    loss2 = 2*intersect[1] / (denominator[1] + eps)
    loss3 = 2*intersect[2] / (denominator[2] + eps)

    return 1 - 2. * intersect_sum / denominator_sum, loss1, loss2, loss3


def Dual_focal_loss(output, target):
    loss1 = Dice(output[:, 1, ...], (target == 1).float())
    loss2 = Dice(output[:, 2, ...], (target == 2).float())
    loss3 = Dice(output[:, 3, ...], (target == 4).float())
    
    if target.dim() == 4:  #(b, h, w, d)
        target[target == 4] = 3  #transfer label 4 to 3
        target = expand_target(target, n_class=output.size()[1])  #extend target from (b, h, w, d) to (b, c, h, w, d)

    target = target.permute(1, 0, 2, 3, 4).contiguous()
    output = output.permute(1, 0, 2, 3, 4).contiguous()
    target = target.view(4, -1)
    output = output.view(4, -1)
    log = 1-(target - output)**2

    return -(F.log_softmax((1-(target - output)**2), 0)).mean(), 1-loss1.data, 1-loss2.data, 1-loss3.data

## **tools.py**

In [9]:
import torch.distributed as dist


def all_reduce_tensor(tensor, op=dist.ReduceOp.SUM, world_size=1):
    tensor = tensor.clone()
    dist.all_reduce(tensor, op)
    tensor.div_(world_size)
    return tensor

In [10]:
def nib_load(file_name):
    if not os.path.exists(file_name):
        print('Invalid file name, can not find the file!')

    proxy = nib.load(file_name)
    data = proxy.get_fdata()
    proxy.uncache()
    return data

In [11]:
import os
os.environ['RANK'] = '0'
os.environ['WORLD_SIZE'] = '1' 
os.environ['MASTER_ADDR'] = '127.0.0.1'  # Set the address of the master process
os.environ['MASTER_PORT'] = '12345'
torch.distributed.init_process_group('nccl')

In [12]:
def append_to_file(data):
    # Open the file in append mode ('a+') or create it if it doesn't exist ('w+')
    with open('file_name.txt', 'a+') as file:
        # Move the cursor to the end of the file
        file.seek(0, 2)
        # Append the data to the file
        file.write(data + '\n')

In [13]:
!conda install -y gdown

Retrieving notices: ...working... done
done
Solving environment: done


  current version: 23.7.4
  latest version: 24.1.2

Please update conda by running

    $ conda update -n base -c conda-forge conda

Or to minimize the number of packages updated during conda update use

     conda install conda=24.1.2



## Package Plan ##

  environment location: /opt/conda

  added / updated specs:
    - gdown


The following packages will be downloaded:

    package                    |            build
    ---------------------------|-----------------
    filelock-3.13.1            |     pyhd8ed1ab_0          15 KB  conda-forge
    gdown-5.1.0                |     pyhd8ed1ab_0          21 KB  conda-forge
    ------------------------------------------------------------
                                           Total:          36 KB

The following NEW packages will be INSTALLED:

  filelock           conda-forge/noarch::filelock-3.13.1-pyhd8ed1ab_0 
  gdown              conda-forge/noarch::gdow

In [14]:
#!gdown --id 1-7DF6aVjeZEO_U27bF4NGxq2VB5zWs5I

In [15]:
!gdown https://www.kaggleusercontent.com/kf/165950140/eyJhbGciOiJkaXIiLCJlbmMiOiJBMTI4Q0JDLUhTMjU2In0..S5HktpP_IyxgmDf7_SMyvw.JDycMnZNcS5-PjMh6o1FN7aUS_Yn9XGeCb6eVymh5_QYUCrE7kQv-h2wrGs3SEz99DxJSebhLAL-sfEg8kCpGPWu2fLGL0_-s5Gxm7FMhgBDqmvZrXfrOddFkTOOTl5XCn2GQ0QRHolqgRwyj-meTb-ZnQaGL2yBrpvoSc-XP8KzXOGJ_X5_BqdYuNQJTmMUpHZfowtwqsa-EXy2g-LMAkzZnbZPK0GDicrzTL6FYSU5nCQMF7L4BPZ4KIs4iNHH67a5CSExaAktEpamLS7-jdz8Koyy8MoaVsuzU2fK-KXv4A7CgTUmgEEIxO-wunmI79JkOr-wU3b8ZxSxNRy-n1aqUzlsncjZP8vSjnoIHKMu35jO2xegriyTm9x8mfllllLATTZsSmvSze1GBPw7_LLUsaZKVrhEgR_jqR42_1TldQl0E62vljCZt8GslXshz3ue36XMx9nhSH5Ty26onO42xV1d0Wid-6PdpLcEXw2ZASJZ2gVwXYkgttoF3DR2AVnqs1EkxgawM1LlvB8Jy6orwWnRl3f35Jh5X9TjIzd7lqiGZ48ZXddxxYj77uJ70jip_cFSXz8efBSQI4vajozWip3L-1XhHRETqWSkbF5S40LVdMWj-NbVCCwd0J8aeA_Un0Mix3YyBbMBuez6eeT9r4--L5J0H1A0b-okS8Q.nxv6vXjfrHVfx7eYcu-BQA/checkpoint/TransBTS2024-03-08/model_epoch_last.pth

Downloading...
From: https://www.kaggleusercontent.com/kf/165950140/eyJhbGciOiJkaXIiLCJlbmMiOiJBMTI4Q0JDLUhTMjU2In0..S5HktpP_IyxgmDf7_SMyvw.JDycMnZNcS5-PjMh6o1FN7aUS_Yn9XGeCb6eVymh5_QYUCrE7kQv-h2wrGs3SEz99DxJSebhLAL-sfEg8kCpGPWu2fLGL0_-s5Gxm7FMhgBDqmvZrXfrOddFkTOOTl5XCn2GQ0QRHolqgRwyj-meTb-ZnQaGL2yBrpvoSc-XP8KzXOGJ_X5_BqdYuNQJTmMUpHZfowtwqsa-EXy2g-LMAkzZnbZPK0GDicrzTL6FYSU5nCQMF7L4BPZ4KIs4iNHH67a5CSExaAktEpamLS7-jdz8Koyy8MoaVsuzU2fK-KXv4A7CgTUmgEEIxO-wunmI79JkOr-wU3b8ZxSxNRy-n1aqUzlsncjZP8vSjnoIHKMu35jO2xegriyTm9x8mfllllLATTZsSmvSze1GBPw7_LLUsaZKVrhEgR_jqR42_1TldQl0E62vljCZt8GslXshz3ue36XMx9nhSH5Ty26onO42xV1d0Wid-6PdpLcEXw2ZASJZ2gVwXYkgttoF3DR2AVnqs1EkxgawM1LlvB8Jy6orwWnRl3f35Jh5X9TjIzd7lqiGZ48ZXddxxYj77uJ70jip_cFSXz8efBSQI4vajozWip3L-1XhHRETqWSkbF5S40LVdMWj-NbVCCwd0J8aeA_Un0Mix3YyBbMBuez6eeT9r4--L5J0H1A0b-okS8Q.nxv6vXjfrHVfx7eYcu-BQA/checkpoint/TransBTS2024-03-08/model_epoch_last.pth
To: /kaggle/working/model_epoch_last.pth
100%|████████████████████████████████████████| 525M/525M [00:

In [16]:
#!pip install monai==1.2.0

In [17]:
#from monai.losses.dice import DiceLoss

In [None]:
import os
import random
import logging
import numpy as np
import time
import setproctitle

import torch
import torch.backends.cudnn as cudnn
import torch.optim
import torch.distributed as dist
import nibabel as nib


from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from torch import nn

from sklearn.model_selection import train_test_split
from torch.utils.data import Subset


local_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())

# Default values for arguments
args = {
    'user': 'name of user',
    'experiment': 'TransBTS',
    'date': local_time.split(' ')[0],
    'description': 'TransBTS, training on train.txt!',
    'root': 'path to training set',
    'train_dir': 'Train',
    'valid_dir': 'Valid',
    'mode': 'train',
    'train_file': 'train.txt',
    'valid_file': 'valid.txt',
    'dataset': 'brats',
    'model_name': 'TransBTS',
    'input_C': 4,
    'input_H': 240,
    'input_W': 240,
    'input_D': 160,
    'crop_H': 128,
    'crop_W': 128,
    'crop_D': 128,
    'output_D': 155,
    'lr': 0.0002,
    'weight_decay': 1e-5,
    'amsgrad': True,
    'criterion': 'softmax_dice',
    'num_class': 4,
    'seed': 1000,
    'no_cuda': False,
    'gpu': '0,1,2,3',
    'num_workers': 8,
    'batch_size': 1,
    'start_epoch': 50,
    'end_epoch': 80,
    'save_freq': 5,
    'resume': '',
    'load': True,
    'local_rank': 0
}

def main_worker():
    if args['local_rank'] == 0:
        __file__ = "./log"

        # Check if the path exists
        if not os.path.exists(__file__):
            # If it doesn't exist, create it recursively
            os.makedirs(__file__)
            print(f"Path '{__file__}' created successfully.")
        else:
            print(f"Path '{__file__}' already exists.")
        log_dir = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'log', args['experiment']+args['date'])
        log_file = log_dir + '.txt'
        log_args(log_file)
        logging.info('--------------------------------------This is all argsurations----------------------------------')
        for arg in args:
            logging.info('{}={}'.format(arg, args[arg]))
        logging.info('----------------------------------------This is a halving line----------------------------------')
        logging.info('{}'.format(args['description']))

    torch.manual_seed(args['seed'])
    torch.cuda.manual_seed(args['seed'])
    random.seed(args['seed'])
    np.random.seed(args['seed'])
#     torch.distributed.init_process_group('nccl')
    torch.cuda.set_device(args['local_rank'])

    _, model = TransBTS(dataset='brats', _conv_repr=True, _pe_type="learned")

    model.cuda(args['local_rank'])
    model = nn.parallel.DistributedDataParallel(model, device_ids=[args['local_rank']], output_device=args['local_rank'],
                                                find_unused_parameters=True)
    model.train()

    optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'], weight_decay=args['weight_decay'], amsgrad=args['amsgrad'])
    criterion = softmax_dice

    if args['local_rank'] == 0:
        checkpoint_dir = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'checkpoint', args['experiment']+args['date'])
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

    resume = '/kaggle/working/model_epoch_last.pth'

    writer = SummaryWriter()

    if os.path.isfile(resume) and args['load']:
        logging.info('loading checkpoint {}'.format(resume))
        checkpoint = torch.load(resume, map_location=lambda storage, loc: storage)

        model.load_state_dict(checkpoint['state_dict'])

        logging.info('Successfully loading checkpoint {} and training from epoch: {}'
                     .format(args['resume'], args['start_epoch']))
    else:
        logging.info('re-training!!!')

    train_list = os.path.join('/kaggle/input/filestxt/train.txt')
    train_root = os.path.join('/kaggle/input/brats20-dataset-training-validation/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData')
    train_set = BraTS(train_list, train_root, args['mode'])
    
    ##############################
    
    # Get list of all data indices
    all_indices = list(range(len(train_set)))

     # Split indices into training and validation sets
    train_indices, valid_indices = train_test_split(all_indices, test_size=0.2, random_state=args['seed'])

     
    # Create Subset datasets from train_set using the split indices
    train_subset = Subset(train_set, train_indices)
    valid_subset = Subset(train_set, valid_indices)

   # Define samplers for distributed training
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_subset)
    valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_subset)

    # Create data loaders for training and validation
    train_loader = DataLoader(dataset=train_subset, sampler=train_sampler, batch_size=1,
                          drop_last=True, num_workers=args['num_workers'], pin_memory=True)
    valid_loader = DataLoader(dataset=valid_subset, sampler=valid_sampler, batch_size=1,
                          drop_last=True, num_workers=args['num_workers'], pin_memory=True)

  #############################
    print("train indices :" ,len(train_loader),"  valid indices :" ,len(valid_loader) )
    
    
    #train_sampler = torch.utils.data.distributed.DistributedSampler(train_set)
    num = len(train_set)
    logging.info('Samples for train = {}'.format(len(train_set)))


    num_gpu = (len(args['gpu'])+1) // 2

    #train_loader = DataLoader(dataset=train_set, sampler=train_sampler, batch_size=1,drop_last=True, num_workers=args['num_workers'], pin_memory=True)

    start_time = time.time()

    torch.set_grad_enabled(True)
    
    

    for epoch in range(args['start_epoch'], args['end_epoch']):
        train_sampler.set_epoch(epoch)  # shuffle
        setproctitle.setproctitle('{}: {}/{}'.format(args['user'], epoch+1, args['end_epoch']))
        start_epoch = time.time()
        total= 0
        total1 =0
        total2=0
        total3 = 0
        
        #training loop
        for i, data in enumerate(train_loader):

            adjust_learning_rate(optimizer, epoch, args['end_epoch'], args['lr'])

            x, target = data
            x = x.cuda(args['local_rank'], non_blocking=True)
            target = target.cuda(args['local_rank'], non_blocking=True)


            output = model(x)
    
            loss, loss1, loss2, loss3 = criterion(output, target)

            reduce_loss = all_reduce_tensor(loss, world_size=num_gpu).data.cpu().numpy()
            reduce_loss1 = all_reduce_tensor(loss1, world_size=num_gpu).data.cpu().numpy()
            reduce_loss2 = all_reduce_tensor(loss2, world_size=num_gpu).data.cpu().numpy()
            reduce_loss3 = all_reduce_tensor(loss3, world_size=num_gpu).data.cpu().numpy()
            total += reduce_loss 
            total1 += reduce_loss1
            total2 += reduce_loss2
            total3 += reduce_loss3
            
            
            if args['local_rank'] == 0:
                logging.info('Epoch: {}_Iter:{}  loss: {:.5f} || 1:{:.4f} | 2:{:.4f} | 3:{:.4f} ||'
                             .format(epoch, i, reduce_loss, reduce_loss1, reduce_loss2, reduce_loss3))
                             

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        #validation loop    
        model.eval()  # Set model to evaluation mode
        
        total_val_loss = 0
        total_val_loss1 = 0
        total_val_loss2 = 0
        total_val_loss3 = 0
            
        total_val_batches = 0

        for i, data in enumerate(valid_loader):
                # Similar to training loop, compute and accumulate validation loss
                x_val, target_val = data
                x_val = x_val.cuda(args['local_rank'], non_blocking=True)
                target_val = target_val.cuda(args['local_rank'], non_blocking=True)

                output_val = model(x_val)
                loss_val,loss_val1, loss_val2,loss_val3 = criterion(output_val, target_val)
                reduce_loss_val = all_reduce_tensor(loss_val, world_size=num_gpu).data.cpu().numpy()
                reduce_loss_val1 = all_reduce_tensor(loss_val1, world_size=num_gpu).data.cpu().numpy()
                reduce_loss_val2 = all_reduce_tensor(loss_val2, world_size=num_gpu).data.cpu().numpy()
                reduce_loss_val3 = all_reduce_tensor(loss_val3, world_size=num_gpu).data.cpu().numpy()
                
                total_val_loss += reduce_loss_val
                total_val_loss1 += reduce_loss_val1
                total_val_loss2 += reduce_loss_val2
                total_val_loss3 += reduce_loss_val3
                total_val_batches += 1

            # Compute average validation loss
        average_val_loss = total_val_loss / total_val_batches
        average_val_loss1 = total_val_loss1 / total_val_batches
        average_val_loss2 = total_val_loss2 / total_val_batches
        average_val_loss3 = total_val_loss3 / total_val_batches
        
        del x_val, target_val
          
        # Print average validation loss
        if args['local_rank'] == 0:
            logging.info('Epoch: {} Validation Loss: {:.5f}|| 1:{:.4f} | 2:{:.4f} | 3:{:.4f} ||'.format(epoch, average_val_loss,average_val_loss1, average_val_loss2,average_val_loss3))
        

        model.train()  # Set model back to training mode

        

        end_epoch = time.time()
        total,total1,total2,total3
        append_to_file('Epoch: {} Training loss: {:.5f} || 1:{:.4f} | 2:{:.4f} | 3:{:.4f} ||'
                             .format(epoch, total/num, total1/num, total2/num, total3 / num) )
        append_to_file('Epoch: {} Validation Loss: {:.5f}|| 1:{:.4f} | 2:{:.4f} | 3:{:.4f} ||'
                       .format(epoch, average_val_loss,average_val_loss1, average_val_loss2,average_val_loss3))
        
        
        
        
        if args['local_rank'] == 0:
            if (epoch + 1) % int(args['save_freq']) == 0 \
                    or (epoch + 1) % int(args['end_epoch'] - 1) == 0 \
                    or (epoch + 1) % int(args['end_epoch'] - 2) == 0 \
                    or (epoch + 1) % int(args['end_epoch'] - 3) == 0:
                file_name = os.path.join(checkpoint_dir, 'model_epoch_{}.pth'.format(epoch))
                torch.save({
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optim_dict': optimizer.state_dict(),
                },
                    file_name)

            writer.add_scalar('lr:', optimizer.param_groups[0]['lr'], epoch)
            writer.add_scalar('loss:', reduce_loss, epoch)
            writer.add_scalar('loss1:', reduce_loss1, epoch)
            writer.add_scalar('loss2:', reduce_loss2, epoch)
            writer.add_scalar('loss3:', reduce_loss3, epoch)
          

        if args['local_rank'] == 0:
            epoch_time_minute = (end_epoch-start_epoch)/60
            remaining_time_hour = (args['end_epoch']-epoch-1)*epoch_time_minute/60
            logging.info('Current epoch time consumption: {:.2f} minutes!'.format(epoch_time_minute))
            logging.info('Estimated remaining training time: {:.2f} hours!'.format(remaining_time_hour))
        if args['local_rank'] == 0:
            writer.close()

            final_name = os.path.join(checkpoint_dir, 'model_epoch_last.pth')
            torch.save({
            'epoch': args['end_epoch'],
            'state_dict': model.state_dict(),
            'optim_dict': optimizer.state_dict(),
        },
            final_name)
    end_time = time.time()
    total_time = (end_time-start_time)/3600
    logging.info('The total training time is {:.2f} hours'.format(total_time))

    logging.info('----------------------------------The training process finished!-----------------------------------')


def adjust_learning_rate(optimizer, epoch, max_epoch, init_lr, power=0.9):
    for param_group in optimizer.param_groups:
        param_group['lr'] = round(init_lr * np.power(1-(epoch) / max_epoch, power), 8)


def log_args(log_file):

    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter(
        '%(asctime)s ===> %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S')

    # args FileHandler to save log file
    fh = logging.FileHandler(log_file)
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(formatter)

    # args StreamHandler to print log to console
    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    ch.setFormatter(formatter)

    # add the two Handler
    logger.addHandler(ch)
    logger.addHandler(fh)

# Run the main worker
main_worker()
     

2024-03-08 12:42:55 ===> --------------------------------------This is all argsurations----------------------------------
2024-03-08 12:42:55 ===> --------------------------------------This is all argsurations----------------------------------
2024-03-08 12:42:55 ===> user=name of user
2024-03-08 12:42:55 ===> user=name of user
2024-03-08 12:42:55 ===> experiment=TransBTS
2024-03-08 12:42:55 ===> experiment=TransBTS
2024-03-08 12:42:55 ===> date=2024-03-08
2024-03-08 12:42:55 ===> date=2024-03-08
2024-03-08 12:42:55 ===> description=TransBTS, training on train.txt!
2024-03-08 12:42:55 ===> description=TransBTS, training on train.txt!
2024-03-08 12:42:55 ===> root=path to training set
2024-03-08 12:42:55 ===> root=path to training set
2024-03-08 12:42:55 ===> train_dir=Train
2024-03-08 12:42:55 ===> train_dir=Train
2024-03-08 12:42:55 ===> valid_dir=Valid
2024-03-08 12:42:55 ===> valid_dir=Valid
2024-03-08 12:42:55 ===> mode=train
2024-03-08 12:42:55 ===> mode=train
2024-03-08 12:42:55 

Path './log' already exists.


2024-03-08 12:42:55 ===> loading checkpoint /kaggle/working/model_epoch_last.pth
2024-03-08 12:42:55 ===> loading checkpoint /kaggle/working/model_epoch_last.pth
2024-03-08 12:42:56 ===> Successfully loading checkpoint  and training from epoch: 50
2024-03-08 12:42:56 ===> Successfully loading checkpoint  and training from epoch: 50
2024-03-08 12:42:56 ===> Samples for train = 368
2024-03-08 12:42:56 ===> Samples for train = 368


train indices : 294   valid indices : 74


2024-03-08 12:43:02 ===> Epoch: 50_Iter:0  loss: 0.45499 || 1:0.0022 | 2:0.1031 | 3:0.1898 ||
2024-03-08 12:43:02 ===> Epoch: 50_Iter:0  loss: 0.45499 || 1:0.0022 | 2:0.1031 | 3:0.1898 ||
2024-03-08 12:43:04 ===> Epoch: 50_Iter:1  loss: 0.31879 || 1:0.0000 | 2:0.2226 | 3:0.2086 ||
2024-03-08 12:43:04 ===> Epoch: 50_Iter:1  loss: 0.31879 || 1:0.0000 | 2:0.2226 | 3:0.2086 ||
2024-03-08 12:43:05 ===> Epoch: 50_Iter:2  loss: 0.39524 || 1:0.0000 | 2:0.1631 | 3:0.1917 ||
2024-03-08 12:43:05 ===> Epoch: 50_Iter:2  loss: 0.39524 || 1:0.0000 | 2:0.1631 | 3:0.1917 ||
2024-03-08 12:43:06 ===> Epoch: 50_Iter:3  loss: 0.43352 || 1:0.0017 | 2:0.1843 | 3:0.1305 ||
2024-03-08 12:43:06 ===> Epoch: 50_Iter:3  loss: 0.43352 || 1:0.0017 | 2:0.1843 | 3:0.1305 ||
2024-03-08 12:43:07 ===> Epoch: 50_Iter:4  loss: 0.33626 || 1:0.0016 | 2:0.2096 | 3:0.2026 ||
2024-03-08 12:43:07 ===> Epoch: 50_Iter:4  loss: 0.33626 || 1:0.0016 | 2:0.2096 | 3:0.2026 ||
2024-03-08 12:43:08 ===> Epoch: 50_Iter:5  loss: 0.69109 || 

In [None]:
from IPython.display import FileLink

FileLink(r'checkpoint/TransBTS2024-02-23/model_epoch_last.pth')

In [None]:
for i, left in enumerate(dataloader):
    print(i)
    with torch.no_grad():
        temp = model(left).view(-1, 1, 300, 300)
    right.append(temp.to('cpu'))
    del temp
    torch.cuda.empty_cache()