In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
import torch
import nibabel as nib
import numpy as np
import os
import glob
import torchvision.transforms as transforms

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [6]:
class BTSDataset(torch.utils.data.Dataset):
    def __init__(self, patient_data_list, no_classes=4):
        self.patient_flair_scans_list = [glob.glob(os.path.join(i, "*_flair.nii.gz"))[0] for i in patient_data_list] 
        print(self.patient_flair_scans_list)
        self.patient_t1ce_scans_list = [glob.glob(os.path.join(i, "*_t1ce.nii.gz"))[0] for i in patient_data_list] 
        print(self.patient_t1ce_scans_list)
        self.patient_t2_scans_list = [glob.glob(os.path.join(i, "*_t2.nii.gz"))[0] for i in patient_data_list] 
        print(self.patient_t2_scans_list)
        self.patient_seg_scans_list = [glob.glob(os.path.join(i, "*_seg.nii.gz"))[0] for i in patient_data_list] 
        print(self.patient_seg_scans_list)
        self.transform = transforms.ToTensor()
        self.no_classes = no_classes
    
    def __len__(self):
        print(len(self.patient_flair_scans_list))
        return len(self.patient_flair_scans_list)
    
    def __getitem__(self, idx):
        t1ce_scan = self.transform(np.asarray(nib.load(self.patient_t1ce_scans_list[idx]).get_fdata()))
        print(t1ce_scan)
        t2_scan = self.transform(np.asarray(nib.load(self.patient_t2_scans_list[idx]).get_fdata()))
        print(t2_scan)
        flair_scan = self.transform(np.asarray(nib.load(self.patient_flair_scans_list[idx]).get_fdata()))
        print(flair_scan)
        seg_label = np.asarray(nib.load(self.patient_seg_scans_list[idx]).get_fdata())
        print(seg_label)
        seg_label[seg_label == 4] = 3 #shape is (240, 240, 155)
        seg_label = self.transform(seg_label)
        seg_label_ohe = torch.nn.functional.one_hot(seg_label.to(torch.int64), self.no_classes) 
        seg_label_ohe = torch.moveaxis(seg_label_ohe, -1, 0) 
        image_scans_stacked = torch.stack([t1ce_scan, t2_scan, flair_scan])
        return image_scans_stacked, seg_label_ohe   #(torch.Size([3, 155, 240, 240]), torch.Size([4, 155, 240, 240]))     

In [7]:
path = "/content/drive/MyDrive/brats_train_10_data"
folder_list = sorted(glob.glob(os.path.join(path, "*")))
folder_list = [i for i in folder_list if os.path.isdir(i)]

In [8]:
dataset = BTSDataset(folder_list)

['/content/drive/MyDrive/brats_train_10_data/BraTS2021_00000/BraTS2021_00000_flair.nii.gz', '/content/drive/MyDrive/brats_train_10_data/BraTS2021_00002/BraTS2021_00002_flair.nii.gz', '/content/drive/MyDrive/brats_train_10_data/BraTS2021_00003/BraTS2021_00003_flair.nii.gz', '/content/drive/MyDrive/brats_train_10_data/BraTS2021_00005/BraTS2021_00005_flair.nii.gz', '/content/drive/MyDrive/brats_train_10_data/BraTS2021_00006/BraTS2021_00006_flair.nii.gz', '/content/drive/MyDrive/brats_train_10_data/BraTS2021_00008/BraTS2021_00008_flair.nii.gz', '/content/drive/MyDrive/brats_train_10_data/BraTS2021_00009/BraTS2021_00009_flair.nii.gz', '/content/drive/MyDrive/brats_train_10_data/BraTS2021_00011/BraTS2021_00011_flair.nii.gz', '/content/drive/MyDrive/brats_train_10_data/BraTS2021_00012/BraTS2021_00012_flair.nii.gz', '/content/drive/MyDrive/brats_train_10_data/BraTS2021_00014/BraTS2021_00014_flair.nii.gz']
['/content/drive/MyDrive/brats_train_10_data/BraTS2021_00000/BraTS2021_00000_t1ce.nii.gz'

In [9]:
len(folder_list)

10

In [10]:
ex_img, ex_label = dataset[0]
ex_img.shape, ex_label.shape

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        ...,

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0., 

(torch.Size([3, 155, 240, 240]), torch.Size([4, 155, 240, 240]))

In [11]:
dataloader = torch.utils.data.DataLoader(dataset, batch_size=10,
                                                      shuffle=True)

10
10


In [12]:
import torch
import torch.nn as nn

In [12]:
# class VNet(nn.Module):
#     def __init__(self):
#         super(VNet, self).__init__()
#         self.conv_block1 = nn.Sequential(
#             nn.Conv3d(3, 16, kernel_size=5, stride=1, padding=2),
#             nn.BatchNorm3d(16),
#             nn.ReLU(),
#             nn.Conv3d(16, 16, kernel_size=5, stride=1, padding=2),
#             nn.BatchNorm3d(16),
#             nn.ReLU()
#         )
#         self.downsample1 = nn.Conv3d(16, 32, kernel_size=2, stride=2, padding=0)
#         self.conv_block2 = nn.Sequential(
#             nn.Conv3d(32, 32, kernel_size=5, stride=1, padding=2),
#             nn.BatchNorm3d(32),
#             nn.ReLU(),
#             nn.Conv3d(32, 32, kernel_size=5, stride=1, padding=2),
#             nn.BatchNorm3d(32),
#             nn.ReLU()
#         )
#         self.downsample2 = nn.Conv3d(32, 64, kernel_size=2, stride=2, padding=0)
#         self.conv_block3 = nn.Sequential(
#             nn.Conv3d(64, 64, kernel_size=5, stride=1, padding=2),
#             nn.BatchNorm3d(64),
#             nn.ReLU(),
#             nn.Conv3d(64, 64, kernel_size=5, stride=1, padding=2),
#             nn.BatchNorm3d(64),
#             nn.ReLU()
#         )
#         self.upsample1 = nn.ConvTranspose3d(64, 32, kernel_size=2, stride=2, padding=0)
#         self.conv_block4 = nn.Sequential(
#             nn.Conv3d(64, 32, kernel_size=5, stride=1, padding=2),
#             nn.BatchNorm3d(32),
#             nn.ReLU(),
#             nn.Conv3d(32, 32, kernel_size=5, stride=1, padding=2),
#             nn.BatchNorm3d(32),
#             nn.ReLU()
#         )
#         self.upsample2 = nn.ConvTranspose3d(32, 16, kernel_size=2, stride=2, padding=0)
#         self.conv_block5 = nn.Sequential(
#             nn.Conv3d(32, 16, kernel_size=5, stride=1, padding=2),
#             nn.BatchNorm3d(16),
#             nn.ReLU(),
#             nn.Conv3d(16, 16, kernel_size=5, stride=1, padding=2),
#             nn.BatchNorm3d(16),
#             nn.ReLU()
#         )
#         self.output = nn.Conv3d(16, 4, kernel_size=1, stride=1, padding=0)

#     def forward(self, x):
#         x = self.conv_block1(x)
#         x = self.downsample1(x)
#         x = self.conv_block2(x)
#         x = self.downsample2(x)
#         x = self.conv_block3(x)
#         x = self.upsample1(x)
#         x = torch.cat([x, self.conv_block2(self.downsample1(self.conv_block1(x)))], dim=1)
#         x = self.conv_block4(x)
#         x = self.upsample2(x)
#         x = torch.cat([x, self.conv_block1(x)], dim=1)
#         x = self.conv_block5(x)
#         x = self.output(x)
#         return x

In [13]:
# model = VNet()

In [None]:
# import torch.optim as optim

# optimizer = optim.Adam(model.parameters(),lr=0.001)
# criterion = nn.CrossEntropyLoss()
# running_loss = 0.0
# for i,data in enumerate(dataloader,0) :
#     optimizer.zero_grad()
#     images,labels =data
    
#     outputs = model(images)
#     loss = criterion(outputs,labels)
    
#     loss.backward()
#     optimizer.step()
    
#     running_loss +=loss.item()
#     if i% 10 ==9:
#         print(f"batch: {i+1},loss: {running_loss/10}")
#         running_loss =0.0

# print("finished training")
    

Similar to the kaggle file:

In [17]:
import torch
import torch.nn as nn

class VNet1(nn.Module):
    def __init__(self, n_filters=8, dropout=0.2, batch_norm=True):
        super(VNet1, self).__init__()
        self.batch_norm = batch_norm
        self.dropout = dropout

        # Encoder
        self.conv1 = nn.Conv3d(3, n_filters, kernel_size=5, stride=1, padding=2)
        self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        self.conv2 = nn.Conv3d(n_filters, n_filters*2, kernel_size=5, stride=1, padding=2)
        self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        self.conv3 = nn.Conv3d(n_filters*2, n_filters*4, kernel_size=5, stride=1, padding=2)
        self.pool3 = nn.MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        self.conv4 = nn.Conv3d(n_filters*4, n_filters*8, kernel_size=5, stride=1, padding=2)
        self.pool4 = nn.MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)

        # Bottleneck
        self.conv5 = nn.Conv3d(n_filters*8, n_filters*16, kernel_size=5, stride=1, padding=2)
        self.deconv6 = nn.ConvTranspose3d(n_filters*16, n_filters*8, kernel_size=2, stride=2, padding=0)
        self.conv6 = nn.Conv3d(n_filters*16, n_filters*8, kernel_size=5, stride=1, padding=2)
        self.deconv7 = nn.ConvTranspose3d(n_filters*8, n_filters*4, kernel_size=2, stride=2, padding=0)
        self.conv7 = nn.Conv3d(n_filters*8, n_filters*4, kernel_size=5, stride=1, padding=2)
        self.deconv8 = nn.ConvTranspose3d(n_filters*4, n_filters*2, kernel_size=2, stride=2, padding=0)
        self.conv8 = nn.Conv3d(n_filters*4, n_filters*2, kernel_size=5, stride=1, padding=2)
        self.deconv9 = nn.ConvTranspose3d(n_filters*2, n_filters, kernel_size=2, stride=2, padding=0)
        self.conv9 = nn.Conv3d(n_filters*2, n_filters, kernel_size=5, stride=1, padding=2)
        self.conv10 = nn.Conv3d(n_filters*2, 4, kernel_size=1, stride=1, padding=0)

        # Activation and dropout
        self.activation = nn.ReLU(inplace=True)
        self.dropout_layer = nn.Dropout3d(p=dropout)

    def forward(self, x, n_filters=8):
        # Encoder
        x = self.conv1(x)
        if self.batch_norm:
            x = nn.BatchNorm3d(n_filters)(x)
        x = self.activation(x)
        x = self.dropout_layer(x)

        x = self.pool1(x)
        x = self.conv2(x)
        if self.batch_norm:
            x = nn.BatchNorm3d(n_filters*2)(x)
        x = self.activation(x)
        x = self.dropout_layer(x)

        x = self.pool2(x)
        x = self.conv3(x)
        if self.batch_norm:
            x = nn.BatchNorm3d(n_filters*4)(x)
        x = self.activation(x)
        x = self.dropout_layer(x)

        x = self.pool3(x)
        x = self.conv4(x)
        if self.batch_norm:
            x = nn.BatchNorm3d(n_filters*8)(x)
        x = self.activation(x)
        x = self.dropout_layer(x)

        # Bottleneck
        x = self.pool4(x)
        x = self.conv5(x)
        if self.batch_norm:
            x = nn.BatchNorm3d(n_filters*16)(x)
        x = self.activation(x)
        x = self.dropout_layer(x)

        x = self.deconv6(x)
        x = torch.cat([x, self.conv4(self.pool3(x))], dim=1)
        x = self.conv6(x)
        if self.batch_norm:
            x = nn.BatchNorm3d(n_filters*8)(x)
        x = self.activation(x)
        x = self.dropout_layer(x)

        x = self.deconv7(x)
        x = torch.cat([x, self.conv3(self.pool2(x))], dim=1)
        x = self.conv7(x)
        if self.batch_norm:
            x = nn.BatchNorm3d(n_filters*4)(x)
        x = self.activation(x)
        x = self.dropout_layer(x)

        x = self.deconv8(x)
        x = torch.cat([x, self.conv2(self.pool1(x))], dim=1)
        x = self.conv8(x)
        if self.batch_norm:
            x = nn.BatchNorm3d(n_filters*2)(x)
        x = self.activation(x)
        x = self.dropout_layer(x)

        x = self.deconv9(x)
        x = torch.cat([x, self.conv1(x)], dim=1)
        x = self.conv9(x)
        if self.batch_norm:
            x = nn.BatchNorm3d(n_filters)(x)
        x = self.activation(x)
        x = self.dropout_layer(x)

        x = self.conv10(torch.cat([x, input], dim=1))
        return x

       


Output size in each layer:

    self.conv1: Conv3d output size = [batch_size, n_filters, D, H, W]
    self.pool1: MaxPool3d output size = [batch_size, n_filters, D/2, H/2, W/2]
    self.conv2: Conv3d output size = [batch_size, 2*n_filters, D/2, H/2, W/2]
    self.pool2: MaxPool3d output size = [batch_size, 2*n_filters, D/4, H/4, W/4]
    self.conv3: Conv3d output size = [batch_size, 4*n_filters, D/4, H/4, W/4]
    self.pool3: MaxPool3d output size = [batch_size, 4*n_filters, D/8, H/8, W/8]
    self.conv4: Conv3d output size = [batch_size, 8*n_filters, D/8, H/8, W/8]
    self.pool4: MaxPool3d output size = [batch_size, 8*n_filters, D/16, H/16, W/16]
    self.conv5: Conv3d output size = [batch_size, 16*n_filters, D/16, H/16, W/16]
    self.deconv6: ConvTranspose3d output size = [batch_size, 8*n_filters, D/8, H/8, W/8]
    self.conv6: Conv3d output size = [batch_size, 8*n_filters, D/8, H/8, W/8]
    self.deconv7: ConvTranspose3d output size = [batch_size, 4*n_filters, D/4, H/4, W/4]
    self.conv7: Conv3d output size = [batch_size, 4*n_filters, D/4, H/4, W/4]
    self.deconv8: ConvTranspose3d output size = [batch_size, 2*n_filters, D/2, H/2, W/2]
    self.conv8: Conv3d output size = [batch_size, 2*n_filters, D/2, H/2, W/2]
    self.deconv9: ConvTranspose3d output size = [batch_size, n_filters, D, H, W]
    self.conv9: Conv3d output size = [batch_size, n_filters, D, H, W]
    self.conv10: Conv3d output size = [batch_size, 4, D, H, W]

In [18]:
model = VNet1()

In [None]:
import torch.optim as optim

optimizer = optim.Adam(model.parameters(),lr=0.001)
criterion = nn.CrossEntropyLoss()
running_loss = 0.0
for i,data in enumerate(dataloader,0) :
    optimizer.zero_grad()
    images,labels =data
    
    outputs = model(images)
    loss = criterion(outputs,labels)
    
    loss.backward()
    optimizer.step()
    
    running_loss +=loss.item()
    if i% 10 ==9:
        print(f"batch: {i+1},loss: {running_loss/10}")
        running_loss =0.0

print("finished training")

10
10
tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        ...,

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0.