In [3]:
!pip install torchio

Collecting torchio
  Using cached torchio-0.20.2-py3-none-any.whl.metadata (50 kB)
Using cached torchio-0.20.2-py3-none-any.whl (175 kB)
Installing collected packages: torchio
Successfully installed torchio-0.20.2


In [4]:
import os, glob, nibabel, random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from torch import Tensor
import torchio as tio
import numpy as np
import pandas as pd
import matplotlib as plt

In [5]:
def double_conv(in_channel, out_channel):
    conv = nn.Sequential(
        nn.Conv2d(in_channel, out_channel, kernel_size=3),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channel, out_channel, kernel_size=3),
        nn.ReLU(inplace=True)
    )
    return conv

def crop_img(tensor, target_tensor):
    target_size = target_tensor.size()[2]
    tensor_size = tensor.size()[2]
    delta = (tensor_size - target_size) // 2
    return tensor[:, :, delta:tensor_size - delta, delta:tensor_size - delta]

In [6]:
# checking devices available
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)

cpu


In [7]:
#Data path
root_path = '/kaggle/input/brats2023-ssa-training-dataset/ASNR-MICCAI-BraTS2023-SSA-Challenge-TrainingData_V2/BraTS-SSA-00007-000/'

In [72]:

class Train_dataset(Dataset):
    def __init__(self, img_path, transform):
        self.img_path = img_path
        self.transform = transform 

    def __len__(self):
        return len(self.img_path)

    def __getitem__(self, index):
        img_path = glob.glob(self.img_path)[index]
        
        mask_path = glob.glob(os.path.join(img_path, '**', '*seg.nii'), recursive=True)[0]
        t1_path = glob.glob(os.path.join(img_path, '**', '*t1n.nii'), recursive=True)[0]
        t2_path = glob.glob(os.path.join(img_path, '**', '*t2w.nii'), recursive=True)[0]
        print(mask_path)
        t1 = nibabel.load(t1_path)
        t1_data = t1.get_fdata()
        
        t2 = nibabel.load(t2_path)
        t2_data = t2.get_fdata()

        mask_img = nibabel.load(mask_path)
        mask_data = t1.get_fdata()

        if self.transform:
            t1_data = self.transform(t1_data)
            t2_data = self.transform(1, t2_data)
            mask_data = self.transform(1, mask_data)
            
        
        image = np.stack([t1_data, t2_data], axis=0)

        return torch.tensor(image, dtype=torch.float32), torch.tensor(mask_data, dtype=torch.float32)

In [73]:
transform = tio.Compose([
    tio.RandomAffine(),
    tio.RandomElasticDeformation(),
    tio.RescaleIntensity((0, 1))
])

dataset = Train_dataset(root_path, transform)
dataloader = DataLoader(dataset, batch_size=10, num_workers=2, shuffle=True)

random_img, random_mask = dataset[0]

/kaggle/input/brats2023-ssa-training-dataset/ASNR-MICCAI-BraTS2023-SSA-Challenge-TrainingData_V2/BraTS-SSA-00007-000/BraTS-SSA-00007-000-seg.nii


ValueError: The input must be a 4D tensor with dimensions (channels, x, y, z) but it has shape (240, 240, 155). Tips: if it is a volume, please add the channels dimension; if it is 2D, also add a dimension of size 1 for the z axis

In [None]:
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()

        self.max_pool_2x2 = nn.MaxPool2d(stride=2, kernel_size=2)
        self.down_conv_1 = double_conv(1, 64)
        self.down_conv_2 = double_conv(64, 128)
        self.down_conv_3 = double_conv(128, 256)
        self.down_conv_4 = double_conv(256, 512)
        self.down_conv_5 = double_conv(512, 1024)

        self.up_trans_1 = nn.ConvTranspose2d(
            in_channels = 1024,
            out_channels = 512,
            stride = 2,
            kernel_size = 2
        )
        self.up_conv_1 = double_conv(1024, 512)

        self.up_trans_2 = nn.ConvTranspose2d(
            in_channels = 512,
            out_channels = 256,
            stride = 2,
            kernel_size = 2
        )
        self.up_conv_2 = double_conv(512, 256)

        self.up_trans_3 = nn.ConvTranspose2d(
            in_channels = 256,
            out_channels = 128,
            stride = 2,
            kernel_size = 2
        )
        self.up_conv_3 = double_conv(256, 128)

        self.up_trans_4 = nn.ConvTranspose2d(
            in_channels = 128,
            out_channels = 64,
            stride = 2,
            kernel_size = 2
        )
        self.up_conv_4 = double_conv(128, 64)
        
        self.output = nn.Conv2d(
            in_channels = 64,
            out_channels = 2,
            kernel_size = 1
        )

    def forward(self, image):
        #encoder
        x1 = self.down_conv_1(image)
        x2 = self.max_pool_2x2(x1)
        x3 = self.down_conv_2(x2)
        x4 = self.max_pool_2x2(x3)
        x5 = self.down_conv_3(x4)
        x6 = self.max_pool_2x2(x5)
        x7 = self.down_conv_4(x6)
        x8 = self.max_pool_2x2(x7)
        x9 = self.down_conv_5(x8)

        x = self.up_trans_1(x9)
        y = crop_img(x7, x)
        x = self.up_conv_1(torch.cat([x, y], 1))

        x = self.up_trans_2(x)
        y = crop_img(x5, x)
        x = self.up_conv_2(torch.cat([x, y], 1))

        x = self.up_trans_3(x)
        y = crop_img(x3, x)
        x = self.up_conv_3(torch.cat([x, y], 1))

        x = self.up_trans_4(x)
        y = crop_img(x1, x)
        x = self.up_conv_4(torch.cat([x, y], 1))

        output = self.output(x)
        return output


if __name__ == '__main__':
    model = UNet()
    image = torch.rand([1, 1, 572, 572])
    model.forward(image)