### Dataset ###

Electron microscopy images dataset: https://www.ini.uzh.ch/~acardona/data.html 

Simplified version of the data: train labels [here](https://1drv.ms/u/s!AsGMiJ4RYU7UajzsvWLB8GXnWQ8), train images [here](https://1drv.ms/u/s!AsGMiJ4RYU7Ua0c5VaXY1EgRN88)

Implementation of "Crowdsourcing the creation of image segmentation algorithms for connectomics": https://www.frontiersin.org/articles/10.3389/fnana.2015.00142/full

Architecture: https://lmb.informatik.uni-freiburg.de/Publications/2015/RFB15a/u-net-architecture.png 

In [1]:
from skimage import io
import numpy as np

import torch
from torch import nn
from tqdm.auto import tqdm
import torch.nn.functional as F
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
torch.manual_seed(0)  # seed torch's global RNG so weight initialization and DataLoader shuffling are reproducible

<torch._C.Generator at 0x7f41b05568d0>

In [2]:
class ContractingBlock(nn.Module):
    '''
    ContractingBlock Class
    One step down the U-Net encoder: two unpadded 3x3 convolutions, each
    followed by ReLU, that double the channel count, then a 2x2 max pool
    with stride 2 that halves the spatial size.
    Values:
        input_channels: the number of channels to expect from a given input;
            the block outputs input_channels * 2 channels
    '''
    def __init__(self, input_channels):
        super().__init__()
        doubled = input_channels * 2
        self.conv1 = nn.Conv2d(input_channels, doubled, kernel_size=3)
        self.conv2 = nn.Conv2d(doubled, doubled, kernel_size=3)
        self.activation = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        '''
        Run one contracting step.
        Parameters:
            x: image tensor of shape (batch size, channels, height, width)
        Returns:
            Tensor of shape (batch size, 2 * channels, (height - 4) // 2,
            (width - 4) // 2); each unpadded 3x3 conv trims 2 pixels per side pair.
        '''
        hidden = self.activation(self.conv1(x))
        hidden = self.activation(self.conv2(hidden))
        return self.maxpool(hidden)

In [3]:
def crop(image, new_shape=None):
    '''
    Function for cropping an image tensor: Given an image tensor and the new shape,
    crops to the center pixels.
    Parameters:
        image: image tensor whose last two dimensions are (height, width),
            e.g. of shape (batch size, channels, height, width)
        new_shape: a torch.Size (or any sequence) whose last two entries are the
            target height and width; leading entries (batch, channels) are ignored.
            If None, the image is returned unchanged.
    Returns:
        A view of image restricted to the central h_new x w_new window.
    Raises:
        ValueError: if the target height or width is larger than the image's,
            which would otherwise produce a silently wrong slice.
    '''
    if new_shape is None:
        return image

    h, w = image.shape[-2:]
    h_new, w_new = (int(dim) for dim in new_shape[-2:])
    if h_new > h or w_new > w:
        raise ValueError(
            f"cannot crop a {h}x{w} image to a larger {h_new}x{w_new} window"
        )

    # Align the centers of the source and target windows (both centers use
    # floor division, matching the original rounding).
    top = h // 2 - h_new // 2
    left = w // 2 - w_new // 2
    return image[..., top:top + h_new, left:left + w_new]

In [4]:
class ExpandingBlock(nn.Module):
    '''
    ExpandingBlock Class
    One step up the U-Net decoder: bilinear 2x upsampling, a 2x2 convolution
    that halves the channel count, concatenation with the center-cropped skip
    connection, then two unpadded 3x3 convolutions, each followed by ReLU.
    Values:
        input_channels: the number of channels to expect from a given input;
            the block outputs input_channels // 2 channels
    '''
    def __init__(self, input_channels):
        super().__init__()
        halved = input_channels // 2
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv1 = nn.Conv2d(input_channels, halved, kernel_size=2)
        self.conv2 = nn.Conv2d(input_channels, halved, kernel_size=3)
        self.conv3 = nn.Conv2d(halved, halved, kernel_size=3)
        self.activation = nn.ReLU()

    def forward(self, x, skip_con_x):
        '''
        Run one expanding step.
        Parameters:
            x: image tensor of shape (batch size, channels, height, width)
            skip_con_x: the image tensor from the contracting path (from the
                opposing block of x) for the skip connection. Because conv2
                consumes input_channels channels after concatenation, for even
                input_channels this tensor must have input_channels // 2 channels.
        Returns:
            Tensor with input_channels // 2 channels.
        '''
        up = self.conv1(self.upsample(x))
        # Trim the skip tensor to the upsampled tensor's height/width before merging.
        skip = crop(skip_con_x, up.shape)
        merged = torch.cat([up, skip], dim=1)
        merged = self.activation(self.conv2(merged))
        return self.activation(self.conv3(merged))

In [5]:
class FeatureMapBlock(nn.Module):
    '''
    FeatureMapBlock Class
    A 1x1 convolution that remaps the channel dimension of every pixel while
    leaving the spatial size untouched. The UNet uses it both as its first
    layer (input channels -> hidden channels) and its last layer
    (hidden channels -> output channels).
    Values:
        input_channels: the number of channels to expect from a given input
        output_channels: the number of channels to produce
    '''
    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=1)

    def forward(self, x):
        '''
        Map x to the desired number of channels.
        Parameters:
            x: image tensor of shape (batch size, channels, height, width)
        Returns:
            Tensor of shape (batch size, output_channels, height, width).
        '''
        return self.conv(x)

In [6]:
class UNet(nn.Module):
    '''
    UNet Class
    A 1x1 upfeature layer, 4 contracting blocks, 4 expanding blocks and a 1x1
    downfeature layer that together transform an input image into the
    corresponding paired image (here, a segmentation map). The convolutions are
    unpadded, so the output is spatially smaller than the input
    (e.g. 512x512 in -> 373x373 out).
    Values:
        input_channels: the number of channels to expect from a given input
        output_channels: the number of channels to expect for a given output
        hidden_channels: channels produced by the upfeature layer; each
            contracting block doubles it and each expanding block halves it
    '''
    def __init__(self, input_channels, output_channels, hidden_channels=64):
        super().__init__()
        # Modules are created in the same order as before so seeded weight
        # initialization stays identical.
        self.upfeature = FeatureMapBlock(input_channels, hidden_channels)
        self.contract1 = ContractingBlock(hidden_channels)
        self.contract2 = ContractingBlock(hidden_channels * 2)
        self.contract3 = ContractingBlock(hidden_channels * 4)
        self.contract4 = ContractingBlock(hidden_channels * 8)
        self.expand1 = ExpandingBlock(hidden_channels * 16)
        self.expand2 = ExpandingBlock(hidden_channels * 8)
        self.expand3 = ExpandingBlock(hidden_channels * 4)
        self.expand4 = ExpandingBlock(hidden_channels * 2)
        self.downfeature = FeatureMapBlock(hidden_channels, output_channels)

    def forward(self, x):
        '''
        Pass x through the U-Net and return the raw (pre-sigmoid) output.
        Parameters:
            x: image tensor of shape (batch size, channels, height, width)
        '''
        # Encoder: keep every intermediate map; they feed the skip connections.
        features = [self.upfeature(x)]
        for down in (self.contract1, self.contract2, self.contract3, self.contract4):
            features.append(down(features[-1]))

        # Decoder: start from the deepest map and pair each expanding block
        # with the encoder output one level shallower (last in, first out).
        out = features.pop()
        for up in (self.expand1, self.expand2, self.expand3, self.expand4):
            out = up(out, features.pop())

        return self.downfeature(out)

In [7]:
# Loss on raw logits: BCEWithLogitsLoss applies the sigmoid internally, so the
# UNet output is left un-activated and torch.sigmoid is only used for display.
criterion = nn.BCEWithLogitsLoss()
n_epochs = 200
input_dim = 1        # channels of the input images
label_dim = 1        # channels of the predicted segmentation map
display_step = 20    # print the loss and plot images every this many optimizer steps
batch_size = 4
lr = 0.0002
initial_shape = 512  # input image height/width in pixels
# Output height/width of the UNet for a 512x512 input (its unpadded
# convolutions shrink the map); labels and displayed inputs are cropped to it.
target_shape = 373
# Fall back to the CPU so `device` never names a device that is not present.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [8]:
# import tarfile

# def extract_tar_bz2(filename, path="/content/"):
#     with tarfile.open(filename, "r:bz2") as tar:
#         tar.extractall(path)

# extract_tar_bz2('Seg.tar.bz2')          # image data is here
# extract_tar_bz2('synapses.tif.tar.bz2') # specific labels here! It creates synapses.tif

In [9]:
#!rsync -a /content/180-220-sub/    /content/180-220-int/  # move all 30 .tif files in one folder

In [11]:
# from skimage import io
# import numpy as np

### loop in all 30 tif files and concatenate across first (dim=0) dimension --> torch.Size([30, 1, 512, 512])
### do this to each file before concatenating:
# imagedata = io.imread('180-220-int/180-220-int-00.tif')[None, None, :, :] / 255
# imagedata = torch.from_numpy(imagedata).reshape(1, 1, 512, 512)

# labels = io.imread('synapses.tif', plugin="tifffile")[:, None, :, :] / 255 
# labels = torch.from_numpy(labels)

# print(imagedata.shape)
# print(labels.shape)

# labels = crop(labels, torch.Size([len(labels), 1, target_shape, target_shape]))
# dataset = torch.utils.data.TensorDataset(imagedata.type(torch.cuda.FloatTensor).to(device), labels.type(torch.cuda.FloatTensor).to(device))

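### Alternative to the (commented-out) code in this cell and the two cells above: load the already-stacked image and label volumes directly (e.g. train-volume.tif / train-labels.tif, stored here as /content/imgs.tif and /content/labels.tif)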

# Load the stacked volumes (one slice per entry along dim 0), insert a
# singleton channel axis -> (N, 1, H, W), and rescale by 1/255.
# NOTE(review): dividing by 255 assumes 8-bit images — confirm for other bit depths.
imgs = torch.Tensor(io.imread('/content/imgs.tif'))[:, None, :, :].div(255)
labels = torch.Tensor(io.imread('/content/labels.tif', plugin="tifffile"))[:, None, :, :].div(255)
# The UNet output is smaller than its input (unpadded convolutions), so the
# labels are center-cropped to target_shape to line up with the predictions.
labels = crop(labels, torch.Size([labels.shape[0], 1, target_shape, target_shape]))
dataset = torch.utils.data.TensorDataset(imgs, labels)

In [12]:
def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28), nrow=4):
    '''
    Function for visualizing images: Given a tensor of images, the number of
    images to show, and the (channels, height, width) size of each image,
    plots the images in a uniform grid with matplotlib.
    Parameters:
        image_tensor: tensor holding the images; it is reshaped to (-1, *size),
            so its element count must be a multiple of the product of size.
            It may require grad, live on the GPU, or be non-contiguous
            (e.g. the output of crop).
        num_images: maximum number of images to draw
        size: (channels, height, width) of a single image
        nrow: number of images per grid row (default 4, as before)
    '''
    # detach() + cpu() so tensors that require grad or live on the GPU can be
    # plotted; reshape (not view) also accepts non-contiguous slices.
    images = image_tensor.detach().cpu().reshape(-1, *size)
    grid = make_grid(images[:num_images], nrow=nrow)
    # make_grid returns (C, H, W); matplotlib expects channels last.
    plt.imshow(grid.permute(1, 2, 0).squeeze())
    plt.show()

In [None]:
def train():
    '''
    Train a UNet on the module-level `dataset`, using the module-level
    hyperparameters (batch_size, lr, n_epochs, criterion, display_step,
    input_dim, label_dim, target_shape, device).

    Every `display_step` optimizer steps it prints the current loss and plots
    the center-cropped inputs, the target masks and the predicted masks.

    Returns:
        The trained UNet. Earlier versions discarded the model; callers that
        ignore the return value are unaffected.
    '''
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    # Use the configured device only when CUDA is actually available;
    # otherwise everything stays on the CPU (as before).
    run_device = device if torch.cuda.is_available() else 'cpu'
    unet = UNet(input_dim, label_dim).to(run_device)
    unet_opt = torch.optim.Adam(unet.parameters(), lr=lr)
    cur_step = 0

    for epoch in range(n_epochs):
        # `targets` rather than `labels` so the global label tensor isn't shadowed.
        for real, targets in tqdm(dataloader):
            real = real.to(run_device)
            targets = targets.to(run_device)

            unet_opt.zero_grad()
            pred = unet(real)
            unet_loss = criterion(pred, targets)
            unet_loss.backward()
            unet_opt.step()

            if cur_step % display_step == 0:
                print(f"Epoch {epoch}: Step {cur_step}: U-Net loss: {unet_loss.item()}")
                # Inputs are larger than the predictions; crop them so the
                # three plotted grids line up pixel for pixel.
                show_tensor_images(
                    crop(real, torch.Size([len(real), 1, target_shape, target_shape])),
                    size=(input_dim, target_shape, target_shape)
                )
                show_tensor_images(targets, size=(label_dim, target_shape, target_shape))
                show_tensor_images(torch.sigmoid(pred), size=(label_dim, target_shape, target_shape))
            cur_step += 1

    return unet


# Keep the trained model for later inference (also stops the notebook from
# echoing the model's repr as cell output).
unet = train()

## Achieved Result ##

<figure>
<img src=https://1drv.ms/u/s!AsGMiJ4RYU7UgRKtIAsAAfk3rTvM alt="micro" style="width:100%">
<figcaption align = "center"><b>image of cells, target segmentation, generated segmentation</b></figcaption>
</figure>