In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch import optim
import torchvision
from torch.utils.data import DataLoader
from skimage import io
import numpy as np

In [2]:
from skimage import io
import os

In [3]:
class double_conv_relu(nn.Module):
    
    def __init__(self, in_channels, out_channels, dropout=False):
        super(double_conv_relu, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.drop = nn.Dropout2d(p=0.2)
        self.norm1 = nn.BatchNorm2d(out_channels)
        self.norm2 = nn.BatchNorm2d(out_channels)
        self.ReLU = nn.ReLU(inplace=True)
        self.dropout = dropout
    def forward(self, x):
        out = self.conv1(x)
        out = self.norm1(out)
        out = self.ReLU(out)
        out = self.conv2(out)
        out = self.norm2(out)
        out = self.ReLU(out)
        if(self.dropout):
            out = self.drop(out)
        return out
    


class upsample(nn.Module):
    def __init__(self, in_channels, out_channels, bilinear=False):
        super(upsample, self).__init__()
        
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear')
        else:
            self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        
    def forward(self, x):
        out = self.up(x)
        return out
    
class concatenate_conv(nn.Module):
    def __init__(self, layer_size):
        super(concatenate_conv, self).__init__()
        self.conv = double_conv_relu(layer_size*2, layer_size)
        
    def forward(self, encoder_layer, decoder_layer):
        out = torch.cat([encoder_layer, decoder_layer], dim=1)
        out = self.conv(out)
        return out
        

In [6]:
class unet(nn.Module):
    def __init__(self, in_channels, out_classes, dropout=False):
        super(unet, self).__init__()
        
        self.encoder_conv1 = double_conv_relu(in_channels, 64, dropout)
        self.encoder_conv2 = double_conv_relu(64, 128, dropout)
        self.encoder_conv3 = double_conv_relu(128, 256, dropout)
        self.encoder_conv4 = double_conv_relu(256, 512, dropout)
        self.encoder_conv5 = double_conv_relu(512, 1024, dropout) #set out channels to 512 instead of 1024 for memory
        
        self.decoder_conv1 = concatenate_conv(512)
        self.decoder_conv2 = concatenate_conv(256)
        self.decoder_conv3 = concatenate_conv(128)
        self.decoder_conv4 = concatenate_conv(64)
        
        self.up1 = upsample(1024, 512)
        self.up2 = upsample(512, 256)
        self.up3 = upsample(256, 128)
        self.up4 = upsample(128, 64)
        
        self.down = nn.MaxPool2d(kernel_size=2, stride=2)
        
        self.output_conv = nn.Conv2d(64, out_classes, kernel_size=1)
        
    def forward(self, x):
        encode1 = self.encoder_conv1(x)
        out = self.down(encode1)
        encode2 = self.encoder_conv2(out)
        out = self.down(encode2)
        encode3 = self.encoder_conv3(out)
        out = self.down(encode3)
        encode4 = self.encoder_conv4(out)
        out = self.down(encode4)
        encode5 = self.encoder_conv5(out)
        decode = self.up1(encode5)
        decode = self.decoder_conv1(encode4, decode)
        decode = self.up2(decode)
        decode = self.decoder_conv2(encode3, decode)
        decode = self.up3(decode)
        decode = self.decoder_conv3(encode2, decode)
        decode = self.up4(decode)
        decode = self.decoder_conv4(encode1, decode)
        out = self.output_conv(decode)
        
        return out
        
        

In [8]:
model = unet(1,2)
sum(p.numel() for p in model.parameters() if p.requires_grad)

31042434

In [9]:
from torchvision.transforms import ToTensor

def train_model(model, batch_size, epochs, lr=0.1, gpu=False):
    
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    criterion = nn.CrossEntropyLoss()
    
    epoch_loss = 0
    
    for epoch in range(epochs):
        
#         data_dir = os.path.join((os.getcwd()), 'data')
#         labels = io.imread(os.path.join(data_dir, 'train-labels.tif')) #load training labels
#         labels = ToTensor()(labels)
#         labels.requires_grad = False
#         labels = labels.transpose(0,1) #needed because of the TIF files
        
# #         labels = labels.unsqueeze(1)
#         labels = labels[0]
#         labels = labels.unsqueeze(0)
#         labels = torch.Tensor.long(labels)
#         labels = Variable(labels)
        
#         imgs = io.imread(os.path.join(data_dir, 'train-volume.tif')) #load training data
#         imgs = ToTensor()(imgs)
#         imgs = imgs.transpose(0,1)
#         imgs.requires_grad = False
#         imgs = imgs.unsqueeze(1)
#         imgs = imgs[0]
#         imgs = imgs.unsqueeze(0)
#         imgs = Variable(imgs)
#         if gpu:
#             imgs = imgs.cuda()
#             labels = labels.cuda()

        x = Variable(torch.FloatTensor(np.random.random((2, 1, 256, 256))))
            
        
        pred_masks = model(x)
#         loss = criterion(pred_masks, labels)
        loss = torch.sum(pred_masks)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        gc.collect()
        del x, pred_masks
        
        epoch_loss += loss
        print('Epoch {}, loss: {}'.format(epoch, epoch_loss))

In [10]:
model = unet(1, 2)
train_model(model, 1, 5, gpu=False)

Epoch 0, loss: Variable containing:
-9346.3994
[torch.FloatTensor of size 1]

Epoch 1, loss: Variable containing:
-1.5957e+13
[torch.FloatTensor of size 1]



KeyboardInterrupt: 

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvBnRelu(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding, stride):
        super(ConvBnRelu, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding, stride=stride)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class StackEncoder(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(StackEncoder, self).__init__()
        self.convr1 = ConvBnRelu(in_channels, out_channels, kernel_size=(3, 3), stride=1, padding=0)
        self.convr2 = ConvBnRelu(out_channels, out_channels, kernel_size=(3, 3), stride=1, padding=0)
        self.maxPool = nn.MaxPool2d(kernel_size=(2, 2), stride=2)

    def forward(self, x):
        x = self.convr1(x)
        x = self.convr2(x)
        x_trace = x
        x = self.maxPool(x)
        return x, x_trace


class StackDecoder(nn.Module):
    def __init__(self, in_channels, out_channels, upsample_size):
        super(StackDecoder, self).__init__()

        self.upSample = nn.Upsample(size=upsample_size, scale_factor=(2, 2), mode="bilinear")
        self.convr1 = ConvBnRelu(in_channels, out_channels, kernel_size=(3, 3), stride=1, padding=0)
        # Crop + concat step between these 2
        self.convr2 = ConvBnRelu(in_channels, out_channels, kernel_size=(3, 3), stride=1, padding=0)

    def _crop_concat(self, upsampled, bypass):
        """
         Crop y to the (h, w) of x and concat them.
         Used for the expansive path.
        Returns:
            The concatenated tensor
        """
        c = (bypass.size()[2] - upsampled.size()[2]) // 2
        bypass = F.pad(bypass, (-c, -c, -c, -c))

        return torch.cat((upsampled, bypass), 1)

    def forward(self, x, down_tensor):
        x = self.upSample(x)
        x = self.convr1(x)
        x = self._crop_concat(x, down_tensor)
        x = self.convr2(x)
        return x


class UNetOriginal(nn.Module):
    def __init__(self, in_shape):
        super(UNetOriginal, self).__init__()
        channels, height, width = in_shape

        self.down1 = StackEncoder(channels, 64)
        self.down2 = StackEncoder(64, 128)
        self.down3 = StackEncoder(128, 256)
        self.down4 = StackEncoder(256, 512)

        self.center = nn.Sequential(
            ConvBnRelu(512, 1024, kernel_size=(3, 3), stride=1, padding=0),
            ConvBnRelu(1024, 1024, kernel_size=(3, 3), stride=1, padding=0)
        )

        self.up1 = StackDecoder(in_channels=1024, out_channels=512, upsample_size=(56, 56))
        self.up2 = StackDecoder(in_channels=512, out_channels=256, upsample_size=(104, 104))
        self.up3 = StackDecoder(in_channels=256, out_channels=128, upsample_size=(200, 200))
        self.up4 = StackDecoder(in_channels=128, out_channels=64, upsample_size=(392, 392))

        # 1x1 convolution at the last layer
        # Different from the paper is the output size here
        self.output_seg_map = nn.Conv2d(64, 1, kernel_size=(1, 1), padding=0, stride=1)

    def forward(self, x):
        x, x_trace1 = self.down1(x)  # Calls the forward() method of each layer
        x, x_trace2 = self.down2(x)
        x, x_trace3 = self.down3(x)
        x, x_trace4 = self.down4(x)

        x = self.center(x)

        x = self.up1(x, x_trace4)
        x = self.up2(x, x_trace3)
        x = self.up3(x, x_trace2)
        x = self.up4(x, x_trace1)

        out = self.output_seg_map(x)
        out = torch.squeeze(out, dim=1)
        return out

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from collections import OrderedDict
from torch.nn import init
import numpy as np

def conv3x3(in_channels, out_channels, stride=1, 
            padding=1, bias=True, groups=1):    
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=3,
        stride=stride,
        padding=padding,
        bias=bias,
        groups=groups)

def upconv2x2(in_channels, out_channels, mode='transpose'):
    if mode == 'transpose':
        return nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=2,
            stride=2)
    else:
        # out_channels is always going to be the same
        # as in_channels
        return nn.Sequential(
            nn.Upsample(mode='bilinear', scale_factor=2),
            conv1x1(in_channels, out_channels))

def conv1x1(in_channels, out_channels, groups=1):
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=1,
        groups=groups,
        stride=1)


class DownConv(nn.Module):
    """
    A helper Module that performs 2 convolutions and 1 MaxPool.
    A ReLU activation follows each convolution.
    """
    def __init__(self, in_channels, out_channels, pooling=True):
        super(DownConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.pooling = pooling

        self.conv1 = conv3x3(self.in_channels, self.out_channels)
        self.conv2 = conv3x3(self.out_channels, self.out_channels)

        if self.pooling:
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        before_pool = x
        if self.pooling:
            x = self.pool(x)
        return x, before_pool


class UpConv(nn.Module):
    """
    A helper Module that performs 2 convolutions and 1 UpConvolution.
    A ReLU activation follows each convolution.
    """
    def __init__(self, in_channels, out_channels, 
                 merge_mode='concat', up_mode='transpose'):
        super(UpConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.merge_mode = merge_mode
        self.up_mode = up_mode

        self.upconv = upconv2x2(self.in_channels, self.out_channels, 
            mode=self.up_mode)

        if self.merge_mode == 'concat':
            self.conv1 = conv3x3(
                2*self.out_channels, self.out_channels)
        else:
            # num of input channels to conv2 is same
            self.conv1 = conv3x3(self.out_channels, self.out_channels)
        self.conv2 = conv3x3(self.out_channels, self.out_channels)


    def forward(self, from_down, from_up):
        """ Forward pass
        Arguments:
            from_down: tensor from the encoder pathway
            from_up: upconv'd tensor from the decoder pathway
        """
        from_up = self.upconv(from_up)
        if self.merge_mode == 'concat':
            x = torch.cat((from_up, from_down), 1)
        else:
            x = from_up + from_down
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        return x


class UNet(nn.Module):
    """ `UNet` class is based on https://arxiv.org/abs/1505.04597

    The U-Net is a convolutional encoder-decoder neural network.
    Contextual spatial information (from the decoding,
    expansive pathway) about an input tensor is merged with
    information representing the localization of details
    (from the encoding, compressive pathway).

    Modifications to the original paper:
    (1) padding is used in 3x3 convolutions to prevent loss
        of border pixels
    (2) merging outputs does not require cropping due to (1)
    (3) residual connections can be used by specifying
        UNet(merge_mode='add')
    (4) if non-parametric upsampling is used in the decoder
        pathway (specified by upmode='upsample'), then an
        additional 1x1 2d convolution occurs after upsampling
        to reduce channel dimensionality by a factor of 2.
        This channel halving happens with the convolution in
        the tranpose convolution (specified by upmode='transpose')
    """

    def __init__(self, num_classes, in_channels=3, depth=5, 
                 start_filts=64, up_mode='transpose', 
                 merge_mode='concat'):
        """
        Arguments:
            in_channels: int, number of channels in the input tensor.
                Default is 3 for RGB images.
            depth: int, number of MaxPools in the U-Net.
            start_filts: int, number of convolutional filters for the 
                first conv.
            up_mode: string, type of upconvolution. Choices: 'transpose'
                for transpose convolution or 'upsample' for nearest neighbour
                upsampling.
        """
        super(UNet, self).__init__()

        if up_mode in ('transpose', 'upsample'):
            self.up_mode = up_mode
        else:
            raise ValueError("\"{}\" is not a valid mode for "
                             "upsampling. Only \"transpose\" and "
                             "\"upsample\" are allowed.".format(up_mode))
    
        if merge_mode in ('concat', 'add'):
            self.merge_mode = merge_mode
        else:
            raise ValueError("\"{}\" is not a valid mode for"
                             "merging up and down paths. "
                             "Only \"concat\" and "
                             "\"add\" are allowed.".format(up_mode))

        # NOTE: up_mode 'upsample' is incompatible with merge_mode 'add'
        if self.up_mode == 'upsample' and self.merge_mode == 'add':
            raise ValueError("up_mode \"upsample\" is incompatible "
                             "with merge_mode \"add\" at the moment "
                             "because it doesn't make sense to use "
                             "nearest neighbour to reduce "
                             "depth channels (by half).")

        self.num_classes = num_classes
        self.in_channels = in_channels
        self.start_filts = start_filts
        self.depth = depth

        self.down_convs = []
        self.up_convs = []

        # create the encoder pathway and add to a list
        for i in range(depth):
            ins = self.in_channels if i == 0 else outs
            outs = self.start_filts*(2**i)
            pooling = True if i < depth-1 else False

            down_conv = DownConv(ins, outs, pooling=pooling)
            self.down_convs.append(down_conv)

        # create the decoder pathway and add to a list
        # - careful! decoding only requires depth-1 blocks
        for i in range(depth-1):
            ins = outs
            outs = ins // 2
            up_conv = UpConv(ins, outs, up_mode=up_mode,
                merge_mode=merge_mode)
            self.up_convs.append(up_conv)

        self.conv_final = conv1x1(outs, self.num_classes)

        # add the list of modules to current module
        self.down_convs = nn.ModuleList(self.down_convs)
        self.up_convs = nn.ModuleList(self.up_convs)

        self.reset_params()

    @staticmethod
    def weight_init(m):
        if isinstance(m, nn.Conv2d):
            init.xavier_normal(m.weight)
            init.constant(m.bias, 0)


    def reset_params(self):
        for i, m in enumerate(self.modules()):
            self.weight_init(m)


    def forward(self, x):
        encoder_outs = []
         
        # encoder pathway, save outputs for merging
        for i, module in enumerate(self.down_convs):
            x, before_pool = module(x)
            encoder_outs.append(before_pool)

        for i, module in enumerate(self.up_convs):
            before_pool = encoder_outs[-(i+2)]
            x = module(before_pool, x)
        
        # No softmax is used. This means you need to use
        # nn.CrossEntropyLoss is your training script,
        # as this module includes a softmax already.
        x = self.conv_final(x)
        return x


In [9]:
model = UNet(3, depth=5, merge_mode='concat', in_channels=1)
x = Variable(torch.FloatTensor(np.random.random((1, 1, 512, 512))))
out = model(x)
loss = torch.sum(out)
loss.backward()

In [10]:
del x, out

In [11]:
del loss

In [5]:
import gc