In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch import optim
import torchvision
from torch.utils.data import DataLoader
from skimage import io

In [49]:
from skimage import io
import os

In [80]:
class double_conv_relu(nn.Module):
    
    def __init__(self, in_channels, out_channels, dropout=False):
        super(double_conv_relu, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.drop = nn.Dropout2d(p=0.2)
        self.norm = nn.BatchNorm2d(out_channels)
        self.ReLU = nn.ReLU(inplace=True)
        self.dropout = dropout
    def forward(self, x):
        out = self.conv1(x)
        out = self.norm(out)
        out = self.ReLU(out)
        out = self.conv2(out)
        out = self.norm(out)
        out = self.ReLU(out)
        if(self.dropout):
            out = self.drop(out)
        return out
    


class upsample(nn.Module):
    def __init__(self, in_channels, out_channels, bilinear=False):
        super(upsample, self).__init__()
        
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear')
        else:
            self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        
    def forward(self, x):
        out = self.up(x)
        return out
    
class concatenate_conv(nn.Module):
    def __init__(self, layer_size):
        super(concatenate_conv, self).__init__()
        self.conv = double_conv_relu(layer_size*2, layer_size)
        
    def forward(self, encoder_layer, decoder_layer):
        out = torch.cat([encoder_layer, decoder_layer], dim=1)
        out = self.conv(out)
        return out
        

In [81]:
class unet(nn.Module):
    def __init__(self, in_channels, out_classes, dropout=False):
        super(unet, self).__init__()
        
        self.encoder_conv1 = double_conv_relu(in_channels, 64, dropout)
        self.encoder_conv2 = double_conv_relu(64, 128, dropout)
        self.encoder_conv3 = double_conv_relu(128, 256, dropout)
        self.encoder_conv4 = double_conv_relu(256, 512, dropout)
        self.encoder_conv5 = double_conv_relu(512, 512, dropout)
        
        self.decoder_conv1 = concatenate_conv(512)
        self.decoder_conv2 = concatenate_conv(256)
        self.decoder_conv3 = concatenate_conv(128)
        self.decoder_conv4 = concatenate_conv(64)
        
        self.up1 = upsample(512, 512)
        self.up2 = upsample(512, 256)
        self.up3 = upsample(256, 128)
        self.up4 = upsample(128, 64)
        
        self.down = nn.MaxPool2d(kernel_size=2, stride=2)
        
        self.output_conv = nn.Conv2d(64, out_classes, kernel_size=1)
        
    def forward(self, x):
        encode1 = self.encoder_conv1(x)
        out = self.down(encode1)
        encode2 = self.encoder_conv2(out)
        out = self.down(encode2)
        encode3 = self.encoder_conv3(out)
        out = self.down(encode3)
        encode4 = self.encoder_conv4(out)
        out = self.down(encode4)
        encode5 = self.encoder_conv5(out)
        decode = self.up1(encode5)
        decode = self.decoder_conv1(encode4, decode)
        decode = self.up2(decode)
        decode = self.decoder_conv2(encode3, decode)
        decode = self.up3(decode)
        decode = self.decoder_conv3(encode2, decode)
        decode = self.up4(decode)
        decode = self.decoder_conv4(encode1, decode)
        out = self.output_conv(decode)
        
        return out
        
        

In [82]:
model = unet(1, 2)
sum(p.numel() for p in model.parameters() if p.requires_grad)

20548738

In [91]:
from torchvision.transforms import ToTensor

def train_model(model, batch_size, epochs, lr=0.1, gpu=False):
    
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    criterion = nn.BCELoss()
    
    for epoch in range(epochs):
        
        data_dir = os.path.join(os.path.dirname(os.getcwd()), 'data')
        labels = io.imread(os.path.join(data_dir, 'train-labels.tif')) #load training labels
        labels = ToTensor()(labels)
        labels.requires_grad = False
        labels = labels.transpose(0,1) #needed because of the TIF files
        
        labels = labels.unsqueeze(1)
        labels = labels[0]
        labels = labels.unsqueeze(0)
        labels = torch.Tensor.long(labels)
        labels = Variable(labels)
        
        imgs = io.imread(os.path.join(data_dir, 'train-volume.tif')) #load training data
        imgs = ToTensor()(imgs)
        imgs = imgs.transpose(0,1)
        imgs.requires_grad = False
        imgs = imgs.unsqueeze(1)
        imgs = imgs[0]
        imgs = imgs.unsqueeze(0)
        imgs = Variable(imgs)
        if gpu:
            imgs = imgs.cuda()
            labels = labels.cuda()
            
        
        pred_masks = model(imgs)
        sigmoid = nn.Sigmoid()
        loss = criterion(sigmoid(pred_masks), labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        epoch_loss = loss.item()
        print('Epoch {}, loss: {}'.format(epoch, epoch_loss))

In [93]:
model = unet(1, 2)
# model.cuda()
train_model(model, 1, 5, gpu=False)

  "Please ensure they have the same size.".format(target.size(), input.size()))


ValueError: Target and input must have the same number of elements. target nelement (262144) != input nelement (524288)

In [108]:
from torchvision.transforms import ToTensor
data_dir = os.path.join(os.path.dirname(os.getcwd()), 'data')
image = io.imread(os.path.join(data_dir, 'train-labels.tif'))
image = ToTensor()(image)

In [109]:
image.transpose(0,1).unsqueeze(0)


( 0 , 0 ,.,.) = 
   1   1   1  ...    1   1   1
   1   1   1  ...    1   1   1
   1   1   1  ...    1   1   1
     ...       ⋱       ...    
   1   1   1  ...    1   1   1
   1   1   1  ...    1   1   1
   1   1   1  ...    1   1   1

( 0 , 1 ,.,.) = 
   1   1   1  ...    1   1   1
   1   1   1  ...    1   1   1
   1   1   1  ...    1   1   1
     ...       ⋱       ...    
   1   1   1  ...    1   1   1
   1   1   1  ...    1   1   1
   1   1   1  ...    1   1   1

( 0 , 2 ,.,.) = 
   1   1   1  ...    1   1   1
   1   1   1  ...    1   1   1
   1   1   1  ...    1   1   1
     ...       ⋱       ...    
   1   1   1  ...    1   1   1
   1   1   1  ...    1   1   1
   1   1   1  ...    1   1   1
    ... 

( 0 ,27 ,.,.) = 
   1   1   1  ...    1   1   1
   1   1   1  ...    1   1   1
   1   1   1  ...    1   1   1
     ...       ⋱       ...    
   1   1   1  ...    1   1   1
   1   1   1  ...    1   1   1
   1   1   1  ...    1   1   1

( 0 ,28 ,.,.) = 
   1   1   1  ...    1   1   1
  

In [3]:
# sub-parts of the U-Net model

import torch
import torch.nn as nn
import torch.nn.functional as F


class double_conv(nn.Module):
    '''(conv => BN => ReLU) * 2'''
    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class inconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(inconv, self).__init__()
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x):
        x = self.conv(x)
        return x


class down(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(down, self).__init__()
        self.mpconv = nn.Sequential(
            nn.MaxPool2d(2),
            double_conv(in_ch, out_ch)
        )

    def forward(self, x):
        x = self.mpconv(x)
        return x


class up(nn.Module):
    def __init__(self, in_ch, out_ch, bilinear=True):
        super(up, self).__init__()

        #  would be a nice idea if the upsampling could be learned too,
        #  but my machine do not have enough memory to handle all those weights

        
        self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2)

        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        diffX = x1.size()[2] - x2.size()[2]
        diffY = x1.size()[3] - x2.size()[3]
        x2 = F.pad(x2, (diffX // 2, int(diffX / 2),
                        diffY // 2, int(diffY / 2)))
        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x


class outconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x):
        x = self.conv(x)
        return x

In [74]:
# full assembly of the sub-parts to form the complete net


class UNet(nn.Module):
    def __init__(self, n_channels, n_classes):
        super(UNet, self).__init__()
        self.inc = inconv(n_channels, 64)
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.down4 = down(512, 512)
        self.up1 = up(1024, 256)
        self.up2 = up(512, 128)
        self.up3 = up(256, 64)
        self.up4 = up(128, 64)
        self.outc = outconv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.outc(x)
        return x