In [26]:
import torch
import os
import glob 
import torch.nn as nn

In [17]:
# Root of the rescaled dataset: each split folder holds images/ and labels/ with .tif files.
DATA_ROOT = './rescaled'


def list_split_tifs(split, kind, root=DATA_ROOT):
    """Return the sorted paths of all .tif files in ``<root>/<split>/<kind>``.

    ``glob.glob`` returns matches in an order that depends on the file system, so
    the result is sorted: the image and label lists below are paired by index.
    NOTE(review): sorting only pairs them correctly if image and label files share
    names (or at least sort in the same order) -- verify for this dataset.
    """
    return sorted(glob.glob(os.path.join(root, split, kind, '*.tif')))


train_images = list_split_tifs('train', 'images')
train_labels = list_split_tifs('train', 'labels')

val_images = list_split_tifs('validation', 'images')
val_labels = list_split_tifs('validation', 'labels')

test_images = list_split_tifs('test', 'images')
test_labels = list_split_tifs('test', 'labels')

In [18]:
# Every image needs a label at the same index, and training needs at least one sample.
# An empty glob (e.g. a wrong data path) would otherwise pass the length checks silently.
for _split, _images, _labels in (("train", train_images, train_labels),
                                 ("validation", val_images, val_labels),
                                 ("test", test_images, test_labels)):
    assert len(_images) == len(_labels), (
        f"{_split}: found {len(_images)} images but {len(_labels)} labels")
assert len(train_images) > 0, "no training images found -- check the data paths"

print("Have", len(train_images), "training images and", len(val_images), "validation images")

In [20]:
#from functools import partial

def normalize(image, target, channel_wise=True):
    """Min-max normalize ``image`` into [0, 1) and pass ``target`` through unchanged.

    Args:
        image: numpy-style array whose axis 0 is treated as the channel axis,
            e.g. (C, H, W) or (C, D, H, W). NOTE(review): a plain (H, W) image
            would have its rows treated as channels -- add a channel axis first.
        target: returned unchanged, so the function fits an
            ``(image, target)`` transform signature.
        channel_wise: if True (default), each channel is shifted and scaled by its
            own min/max over all remaining axes; if False, one global min/max over
            the whole array is used.

    Returns:
        ``(normalized, target)`` where ``normalized`` is a new float32 array;
        the caller's ``image`` is not modified.
    """
    eps = 1.e-6  # keeps constant channels/images from dividing by zero
    # astype returns a copy, so the in-place ops below never touch the caller's array
    image = image.astype('float32')
    # reduce over every non-channel axis so 3D volumes (C, D, H, W) are normalized
    # per channel rather than per slice of the last axis
    axis = tuple(range(1, image.ndim)) if channel_wise else None
    image -= image.min(axis=axis, keepdims=True)
    image /= (image.max(axis=axis, keepdims=True) + eps)
    return image, target

#def to_tensor():
    
#image = train_images[1]
#normalize(image , 
#def DatasetWithTransform(images , labels , transform)


def get_loader(train_images, train_labels, patch_shape, split, batch_size=1,
               label_transform=None, raw_key=None, label_key=None, num_workers=8):
    """Build a torch_em segmentation loader over the given image / label files.

    Args:
        train_images: list of raw image paths for this split (used for any split,
            despite the name).
        train_labels: list of label paths, paired by index with ``train_images``.
        patch_shape: shape of the patches sampled from each volume, e.g. [128, 128, 128].
        split: ``"train"`` samples 100 patches per epoch; any other value samples 4.
        batch_size: number of patches per batch.
        label_transform: optional callable applied to label patches (forwarded to torch_em).
        raw_key, label_key: dataset keys inside container files (e.g. "image" /
            "labels" for HDF5); None for plain image files. NOTE(review): confirm the
            installed torch_em version loads .tif volumes with a None key.
        num_workers: number of DataLoader worker processes.

    Returns:
        The loader created by ``torch_em.default_segmentation_loader``.

    Raises:
        ValueError: if no images are given or images and labels differ in count.
    """
    # `from torch_em.model import UNet3d` elsewhere does not bind the name `torch_em`
    import torch_em

    if len(train_images) == 0:
        raise ValueError(f"no images given for split '{split}'")
    if len(train_images) != len(train_labels):
        raise ValueError(f"split '{split}': {len(train_images)} images but "
                         f"{len(train_labels)} labels")

    n_samples = 100 if split == "train" else 4
    return torch_em.default_segmentation_loader(
        train_images, raw_key, train_labels, label_key,
        label_transform=label_transform,
        batch_size=batch_size, patch_shape=patch_shape,
        num_workers=num_workers, shuffle=True, is_seg_dataset=True,
        n_samples=n_samples,
    )

In [None]:
from torch_em.model import UNet3d

def train(train_images , train_labels):
    """Set up the 3D UNet and the training loader for the given files.

    The optimization loop is not implemented yet; the function returns the pieces
    it builds so the caller can use them.

    Args:
        train_images: list of training image paths.
        train_labels: list of training label paths, paired by index with the images.

    Returns:
        ``(model, train_loader)``.
    """
    number_labels = 4
    # one output channel per label (the original referenced an undefined `n_out`)
    model = UNet3d(in_channels=1, out_channels=number_labels, final_activation="Sigmoid")

    patch_shape = [128, 128, 128]

    train_loader = get_loader(train_images, train_labels, patch_shape, "train")
    # TODO: build a validation loader from the validation split and add the training loop.
    return model, train_loader
    

In [27]:
class UNet(nn.Module):
    """ UNet implementation
    Arguments:
      in_channels: number of input channels
      out_channels: number of output channels
      final_activation: activation applied to the network output
    """
    
    # _conv_block and _upsampler are just helper functions to
    # construct the model.
    # encapsulating them like so also makes it easy to re-use
    # the model implementation with different architecture elements
    
    # Convolutional block for single layer of the decoder / encoder
    # we apply to 2d convolutions with relu activation
    def _conv_block(self, in_channels, out_channels):
        return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                             nn.ReLU(),
                             nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                             nn.ReLU())       


    # upsampling via transposed 2d convolutions
    def _upsampler(self, in_channels, out_channels):
        return nn.ConvTranspose2d(in_channels, out_channels,
                                kernel_size=2, stride=2)
    
    def __init__(self, in_channels=1, out_channels=1, 
                 final_activation=None):
        super().__init__()
        
        # the depth (= number of encoder / decoder levels) is
        # hard-coded to 4
        self.depth = 4

        # the final activation must either be None or a Module
        if final_activation is not None:
            assert isinstance(final_activation, nn.Module), "Activation must be torch module"
        
        # all lists of conv layers (or other nn.Modules with parameters) must be wraped
        # itnto a nn.ModuleList
        
        # modules of the encoder path
        self.encoder = nn.ModuleList([self._conv_block(in_channels, 16),
                                      self._conv_block(16, 32),
                                      self._conv_block(32, 64),
                                      self._conv_block(64, 128)])
        # the base convolution block
        self.base = self._conv_block(128, 256)
        # modules of the decoder path
        self.decoder = nn.ModuleList([self._conv_block(256, 128),
                                      self._conv_block(128, 64),
                                      self._conv_block(64, 32),
                                      self._conv_block(32, 16)])
        
        # the pooling layers; we use 2x2 MaxPooling
        self.poolers = nn.ModuleList([nn.MaxPool2d(2) for _ in range(self.depth)])
        # the upsampling layers
        self.upsamplers = nn.ModuleList([self._upsampler(256, 128),
                                         self._upsampler(128, 64),
                                         self._upsampler(64, 32),
                                         self._upsampler(32, 16)])
        # output conv and activation
        # the output conv is not followed by a non-linearity, because we apply
        # activation afterwards
        self.out_conv = nn.Conv2d(16, out_channels, 1)
        self.activation = final_activation
    
    def forward(self, input):
        x = input
        # apply encoder path
        encoder_out = []
        for level in range(self.depth):
            x = self.encoder[level](x)
            encoder_out.append(x)
            x = self.poolers[level](x)

        # apply base
        x = self.base(x)
        
        # apply decoder path
        encoder_out = encoder_out[::-1]
        for level in range(self.depth):
            x = self.upsamplers[level](x)
            x = self.decoder[level](torch.cat((x, encoder_out[level]), dim=1))
        
        # apply output conv and activation (if given)
        x = self.out_conv(x)
        if self.activation is not None:
            x = self.activation(x)
        return x

In [22]:
# Select the compute device: use CUDA when PyTorch can see a GPU, otherwise
# fall back to the CPU, and report which case applies.
_gpu_found = torch.cuda.is_available()
print("GPU is available" if _gpu_found else "GPU is not available")
device = torch.device("cuda" if _gpu_found else "cpu")

GPU is available


In [29]:
# start a tensorboard writer
# NOTE: `%load_ext` and `%tensorboard` are IPython magics, so this cell only runs
# inside Jupyter / IPython, and the extension must be loaded before `%tensorboard` is used.
%load_ext tensorboard
from torch.utils.tensorboard import SummaryWriter

# event files for this run are written below runs/Unet
# NOTE(review): re-running this cell creates another writer without closing the
# previous one -- consider calling logger.close() first.
logger = SummaryWriter('runs/Unet')
# point TensorBoard at the parent directory so runs/Unet (and any sibling runs) are shown
%tensorboard --logdir runs

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6008 (pid 1318), started 0:05:13 ago. (Use '!kill 1318' to kill it.)

In [30]:
# Default torch_em 3D UNet: one input channel, one output channel and a
# sigmoid head, so every prediction is squashed into [0, 1].
from torch_em.model import UNet3d

sigmoid_head = nn.Sigmoid()
model = UNet3d(in_channels=1, out_channels=1, final_activation=sigmoid_head)
# nn.Module.to moves the parameters in place and returns the module itself,
# so this trailing expression also displays the model summary
model.to(device)


UNet3d(
  (encoder): Encoder(
    (blocks): ModuleList(
      (0): ConvBlock3d(
        (block): Sequential(
          (0): InstanceNorm3d(1, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (1): Conv3d(1, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
          (2): ReLU(inplace=True)
          (3): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (4): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
          (5): ReLU(inplace=True)
        )
      )
      (1): ConvBlock3d(
        (block): Sequential(
          (0): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (1): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
          (2): ReLU(inplace=True)
          (3): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (4): Conv3d(64, 64, kernel_size=(3, 3, 3