In [1]:
import numpy
from stl import mesh

# Using an existing stl file:
seg_mesh = mesh.Mesh.from_file('data/segmentation299.stl')

In [29]:
seg_mesh.save('data/saved_mes.stl')

In [30]:
seg_mesh.y.shape

(312104, 3)

In [31]:
print(seg_mesh.x)
print(seg_mesh.y)
print(seg_mesh.z)

[[ 9.8728895 10.038902  10.015289 ]
 [10.038902   9.8728895  9.878741 ]
 [ 9.789877   9.8728895 10.015289 ]
 ...
 [-9.286607  -9.124687  -9.130307 ]
 [-8.956034  -9.010082  -8.969039 ]
 [-9.018819  -8.969039  -9.010082 ]]
[[ 19.85566   19.989784  19.802036]
 [ 19.989784  19.85566   19.961197]
 [ 19.590105  19.85566   19.802036]
 ...
 [-55.926697 -55.95069  -55.85525 ]
 [-55.953335 -55.935146 -55.97643 ]
 [-55.97107  -55.97643  -55.935146]]
[[-38.38391  -38.232117 -38.42846 ]
 [-38.232117 -38.38391  -38.189087]
 [-38.511353 -38.38391  -38.42846 ]
 ...
 [ 37.6062    37.808807  37.886204]
 [ 38.05275   38.015656  38.03947 ]
 [ 37.93748   38.03947   38.015656]]


In [36]:
seg_tensor = np.array([seg_mesh.x.T, seg_mesh.y.T, seg_mesh.z.T])

In [1]:
import numpy as np
import nibabel as nib

import torch
import torch.nn as nn
import torch.utils.data
import torchvision

from scipy.ndimage import zoom

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [3]:
def save_vol_as_nii(numpy_arr, loaded_file, path_to_save):
    empty_header = nib.Nifti1Header()
    Nifti1Image = nib.Nifti1Image(numpy_arr, loaded_file.affine, empty_header)
    nib.save(Nifti1Image, path_to_save)

In [4]:
brain_file = nib.load("data/sub-299_ses-20110422_desc-angio_N4bfc_brain_mask.nii.gz")
seg_file = nib.load("data/bad_segmentation.nii.gz")


brain_vol = np.array(brain_file.dataobj)
seg_vol = np.array(seg_file.dataobj)

In [5]:
print(brain_vol.shape)
print(seg_vol.shape)

(512, 512, 140)
(512, 512, 140)


In [6]:
seg_vol.max()

1

In [7]:
brain_vol_compressed = zoom(brain_vol, (0.25, 0.25, 0.45714285714285713))
seg_vol_compressed = zoom(seg_vol, (0.25, 0.25, 0.45714285714285713))

#brain_vol_compressed = zoom(brain_vol, (0.25/2.0, 0.25/2, 0.45714285714285713/2))
#seg_vol_compressed = zoom(seg_vol, (0.25/2, 0.25/2, 0.45714285714285713/2))


In [8]:
seg_vol_compressed[seg_vol_compressed>1] = 1
seg_vol_compressed.max()

1

In [9]:
brain_vol_compressed = torch.tensor(brain_vol_compressed).unsqueeze(0).unsqueeze(0)
seg_vol_compressed = torch.tensor(seg_vol_compressed).unsqueeze(0).unsqueeze(0)
print(brain_vol_compressed.shape)
print(seg_vol_compressed.shape)

torch.Size([1, 1, 128, 128, 64])
torch.Size([1, 1, 128, 128, 64])


In [10]:
class conv_block(nn.Module):
    """
    Convolution Block
    """

    def __init__(self, in_channels, out_channels, k_size=3, stride=1, padding=1, bias=True):
        super(conv_block, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=k_size,
                      stride=stride, padding=padding, bias=bias),
            nn.BatchNorm3d(num_features=out_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=k_size,
                      stride=stride, padding=padding, bias=bias),
            nn.BatchNorm3d(num_features=out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class up_conv(nn.Module):
    """
    Up Convolution Block
    """

    # def __init__(self, in_ch, out_ch):
    def __init__(self, in_channels, out_channels, k_size=3, stride=1, padding=1, bias=True):
        super(up_conv, self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=k_size,
                      stride=stride, padding=padding, bias=bias),
            nn.BatchNorm3d(num_features=out_channels),
            nn.ReLU(inplace=True))

    def forward(self, x):
        x = self.up(x)
        return x


class U_Net(nn.Module):
    """
    UNet - Basic Implementation
    Input _ [batch * channel(# of channels of each image) * depth(# of frames) * height * width].
    Paper : https://arxiv.org/abs/1505.04597
    """

    def __init__(self, in_ch=1, out_ch=1):
        super(U_Net, self).__init__()

        n1 = 64 #TODO: make params
        filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]  # 64,128,256,512,1024

        self.Maxpool1 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool3d(kernel_size=2, stride=2)

        self.Conv1 = conv_block(in_ch, filters[0])
        self.Conv2 = conv_block(filters[0], filters[1])
        self.Conv3 = conv_block(filters[1], filters[2])
        self.Conv4 = conv_block(filters[2], filters[3])
        self.Conv5 = conv_block(filters[3], filters[4])

        self.Up5 = up_conv(filters[4], filters[3])
        self.Up_conv5 = conv_block(filters[4], filters[3])

        self.Up4 = up_conv(filters[3], filters[2])
        self.Up_conv4 = conv_block(filters[3], filters[2])

        self.Up3 = up_conv(filters[2], filters[1])
        self.Up_conv3 = conv_block(filters[2], filters[1])

        self.Up2 = up_conv(filters[1], filters[0])
        self.Up_conv2 = conv_block(filters[1], filters[0])

        self.Conv = nn.Conv3d(filters[0], out_ch, kernel_size=1, stride=1, padding=0)

    self.active = torch.nn.Sigmoid()

    def forward(self, x):
        # print("unet")
        # print(x.shape)
        # print(padded.shape)

        e1 = self.Conv1(x)
        # print("conv1:")
        # print(e1.shape)

        e2 = self.Maxpool1(e1)
        e2 = self.Conv2(e2)
        # print("conv2:")
        # print(e2.shape)

        e3 = self.Maxpool2(e2)
        e3 = self.Conv3(e3)
        # print("conv3:")
        # print(e3.shape)

        e4 = self.Maxpool3(e3)
        e4 = self.Conv4(e4)
        # print("conv4:")
        # print(e4.shape)

        e5 = self.Maxpool4(e4)
        e5 = self.Conv5(e5)
        # print("conv5:")
        # print(e5.shape)

        d5 = self.Up5(e5)
        # print("d5:")
        # print(d5.shape)
        # print("e4:")
        # print(e4.shape)
        d5 = torch.cat((e4, d5), dim=1)
        d5 = self.Up_conv5(d5)
        # print("upconv5:")
        # print(d5.size)

        d4 = self.Up4(d5)
        # print("d4:")
        # print(d4.shape)
        d4 = torch.cat((e3, d4), dim=1)
        d4 = self.Up_conv4(d4)
        # print("upconv4:")
        # print(d4.shape)
        d3 = self.Up3(d4)
        d3 = torch.cat((e2, d3), dim=1)
        d3 = self.Up_conv3(d3)
        # print("upconv3:")
        # print(d3.shape)
        d2 = self.Up2(d3)
        d2 = torch.cat((e1, d2), dim=1)
        d2 = self.Up_conv2(d2)
        # print("upconv2:")
        # print(d2.shape)
        out = self.Conv(d2)
        # print("out:")
        # print(out.shape)
        # d1 = self.active(out)
        out = self.active(out)
        return [out]

class U_Net_DeepSup(nn.Module):
    """
    UNet - Basic Implementation
    Input _ [batch * channel(# of channels of each image) * depth(# of frames) * height * width].
    Paper : https://arxiv.org/abs/1505.04597
    """

    def __init__(self, in_ch=1, out_ch=1):
        super(U_Net_DeepSup, self).__init__()

        n1 = 64
        filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]  # 64,128,256,512,1024

        self.Maxpool1 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool3d(kernel_size=2, stride=2)

        self.Conv1 = conv_block(in_ch, filters[0])
        self.Conv2 = conv_block(filters[0], filters[1])
        self.Conv3 = conv_block(filters[1], filters[2])
        self.Conv4 = conv_block(filters[2], filters[3])
        self.Conv5 = conv_block(filters[3], filters[4])

        #1x1x1 Convolution for Deep Supervision
        self.Conv_d3 = conv_block(filters[1], 1)
        self.Conv_d4 = conv_block(filters[2], 1)



        self.Up5 = up_conv(filters[4], filters[3])
        self.Up_conv5 = conv_block(filters[4], filters[3])

        self.Up4 = up_conv(filters[3], filters[2])
        self.Up_conv4 = conv_block(filters[3], filters[2])

        self.Up3 = up_conv(filters[2], filters[1])
        self.Up_conv3 = conv_block(filters[2], filters[1])

        self.Up2 = up_conv(filters[1], filters[0])
        self.Up_conv2 = conv_block(filters[1], filters[0])

        self.Conv = nn.Conv3d(filters[0], out_ch, kernel_size=1, stride=1, padding=0)

        for submodule in self.modules():
            submodule.register_forward_hook(self.nan_hook)

    # self.active = torch.nn.Sigmoid()

    def nan_hook(self, module, inp, output):
        for i, out in enumerate(output):
            nan_mask = torch.isnan(out)
            if nan_mask.any():
                print("In", self.__class__.__name__)
                print(module)
                raise RuntimeError(f"Found NAN in output {i} at indices: ", nan_mask.nonzero(), "where:", out[nan_mask.nonzero()[:, 0].unique(sorted=True)])

    def forward(self, x):
        # print("unet")
        # print(x.shape)
        # print(padded.shape)

        e1 = self.Conv1(x)
        # print("conv1:")
        # print(e1.shape)

        e2 = self.Maxpool1(e1)
        e2 = self.Conv2(e2)
        # print("conv2:")
        # print(e2.shape)

        e3 = self.Maxpool2(e2)
        e3 = self.Conv3(e3)
        # print("conv3:")
        # print(e3.shape)

        e4 = self.Maxpool3(e3)
        e4 = self.Conv4(e4)
        # print("conv4:")
        # print(e4.shape)

        e5 = self.Maxpool4(e4)
        e5 = self.Conv5(e5)
        # print("conv5:")
        # print(e5.shape)

        d5 = self.Up5(e5)
        # print("d5:")
        # print(d5.shape)
        # print("e4:")
        # print(e4.shape)
        d5 = torch.cat((e4, d5), dim=1)
        d5 = self.Up_conv5(d5)
        # print("upconv5:")
        # print(d5.size)

        d4 = self.Up4(d5)
        # print("d4:")
        # print(d4.shape)
        d4 = torch.cat((e3, d4), dim=1)
        d4 = self.Up_conv4(d4)
        d4_out  = self.Conv_d4(d4)
        
                
        # print("upconv4:")
        # print(d4.shape)
        d3 = self.Up3(d4)
        d3 = torch.cat((e2, d3), dim=1)
        d3 = self.Up_conv3(d3)        
        d3_out  = self.Conv_d3(d3)

        # print("upconv3:")
        # print(d3.shape)
        d2 = self.Up2(d3)
        d2 = torch.cat((e1, d2), dim=1)
        d2 = self.Up_conv2(d2)
        # print("upconv2:")
        # print(d2.shape)
        out = self.Conv(d2)
        # print("out:")
        # print(out.shape)
        # d1 = self.active(out)

        return [out, d3_out , d4_out]

In [11]:
model = U_Net()

In [12]:
total_params = sum(
    param.numel() for param in model.parameters()
)
total_params

103536449

In [13]:
model = model.to(device)

In [14]:
brain_vol_compressed = brain_vol_compressed.to(device)
seg_vol_compressed = seg_vol_compressed.to(device)

In [15]:
out = model(brain_vol_compressed)

In [16]:
print(out[0].shape)
sig = nn.Sigmoid()
out = sig(out[0])
out[out<0.5] = 0
out[out>0.5] = 1

torch.Size([1, 1, 128, 128, 64])


In [17]:
print(out.shape)
print(out.dtype)

print(seg_vol_compressed.shape)
print(seg_vol_compressed.dtype)

torch.Size([1, 1, 128, 128, 64])
torch.float32
torch.Size([1, 1, 128, 128, 64])
torch.int16


In [21]:
loss = nn.BCELoss(reduction='mean')
optim = torch.optim.Adam(model.parameters(), lr=3e-4, betas=(0.9, 0.999), eps=1e-06, weight_decay=0.0)

In [None]:
for epoch in range(10):
    optim.zero_grad()
    out = model(brain_vol_compressed)
    
    

In [19]:
loss(out, seg_vol_compressed.type(torch.float))

tensor(41.6293, device='cuda:0', grad_fn=<BinaryCrossEntropyBackward0>)

In [15]:
empty_header = nib.Nifti1Header()

In [6]:
Nifti1Image = nib.Nifti1Image(brain_vol_compressed, brain_file.affine, empty_header)

In [8]:
nib.save(Nifti1Image, 'data/brain_vol_compressed.nii.gz')

<h2> <b/>Saving example

In [11]:
path_to_brain = "data/sub-299_ses-20110422_desc-angio_N4bfc_brain_mask.nii.gz"
brain_file = nib.load(path_to_brain)
brain_vol = np.array(brain_file.dataobj)
brain_vol_compressed = zoom(brain_vol, (0.25, 0.25, 0.25))

path_to_save = 'data/brain_vol_compressed.nii.gz'
save_vol_as_nii(brain_vol_compressed, brain_file, path_to_save)