In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch import optim
import torchvision
from torch.utils.data import DataLoader
from skimage import io
import numpy as np
import gc
import os

In [4]:
class double_conv_relu(nn.Module):
    
    def __init__(self, in_channels, out_channels, dropout=False):
        super(double_conv_relu, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.drop = nn.Dropout2d(p=0.2)
        self.norm1 = nn.BatchNorm2d(out_channels)
        self.norm2 = nn.BatchNorm2d(out_channels)
        self.ReLU = nn.ReLU(inplace=True)
        self.dropout = dropout
    def forward(self, x):
        out = self.conv1(x)
        out = self.norm1(out)
        out = self.ReLU(out)
        out = self.conv2(out)
        out = self.norm2(out)
        out = self.ReLU(out)
        if(self.dropout):
            out = self.drop(out)
        return out
    


class upsample(nn.Module):
    def __init__(self, in_channels, out_channels, bilinear=False):
        super(upsample, self).__init__()
        
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear')
        else:
            self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        
    def forward(self, x):
        out = self.up(x)
        return out
    
class concatenate_conv(nn.Module):
    def __init__(self, layer_size):
        super(concatenate_conv, self).__init__()
        self.conv = double_conv_relu(layer_size*2, layer_size)
        
    def forward(self, encoder_layer, decoder_layer):
        out = torch.cat([encoder_layer, decoder_layer], dim=1)
        out = self.conv(out)
        return out
        

In [5]:
class unet(nn.Module):
    def __init__(self, in_channels, out_classes, dropout=False):
        super(unet, self).__init__()
        
        self.encoder_conv1 = double_conv_relu(in_channels, 64, dropout)
        self.encoder_conv2 = double_conv_relu(64, 128, dropout)
        self.encoder_conv3 = double_conv_relu(128, 256, dropout)
        self.encoder_conv4 = double_conv_relu(256, 512, dropout)
        self.encoder_conv5 = double_conv_relu(512, 1024, dropout) #set out channels to 512 instead of 1024 for memory
        
        self.decoder_conv1 = concatenate_conv(512)
        self.decoder_conv2 = concatenate_conv(256)
        self.decoder_conv3 = concatenate_conv(128)
        self.decoder_conv4 = concatenate_conv(64)
        
        self.up1 = upsample(1024, 512)
        self.up2 = upsample(512, 256)
        self.up3 = upsample(256, 128)
        self.up4 = upsample(128, 64)
        
        self.down = nn.MaxPool2d(kernel_size=2, stride=2)
        
        self.output_conv = nn.Conv2d(64, out_classes, kernel_size=1)
        
    def forward(self, x):
        encode1 = self.encoder_conv1(x)
        out = self.down(encode1)
        encode2 = self.encoder_conv2(out)
        out = self.down(encode2)
        encode3 = self.encoder_conv3(out)
        out = self.down(encode3)
        encode4 = self.encoder_conv4(out)
        out = self.down(encode4)
        encode5 = self.encoder_conv5(out)
        decode = self.up1(encode5)
        decode = self.decoder_conv1(encode4, decode)
        decode = self.up2(decode)
        decode = self.decoder_conv2(encode3, decode)
        decode = self.up3(decode)
        decode = self.decoder_conv3(encode2, decode)
        decode = self.up4(decode)
        decode = self.decoder_conv4(encode1, decode)
        out = self.output_conv(decode)
        
        return out
        
        

In [6]:
model = unet(1,2)
sum(p.numel() for p in model.parameters() if p.requires_grad)

31042434

In [9]:
from torchvision.transforms import ToTensor

def train_model(model, batch_size, epochs, lr=0.1, gpu=False):
    
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    criterion = nn.CrossEntropyLoss()
    
    epoch_loss = 0
    
    for epoch in range(epochs):
        
#         data_dir = os.path.join((os.getcwd()), 'data')
#         labels = io.imread(os.path.join(data_dir, 'train-labels.tif')) #load training labels
#         labels = ToTensor()(labels)
#         labels.requires_grad = False
#         labels = labels.transpose(0,1) #needed because of the TIF files
        
# #         labels = labels.unsqueeze(1)
#         labels = labels[0]
#         labels = labels.unsqueeze(0)
#         labels = torch.Tensor.long(labels)
#         labels = Variable(labels)
        
#         imgs = io.imread(os.path.join(data_dir, 'train-volume.tif')) #load training data
#         imgs = ToTensor()(imgs)
#         imgs = imgs.transpose(0,1)
#         imgs.requires_grad = False
#         imgs = imgs.unsqueeze(1)
#         imgs = imgs[0]
#         imgs = imgs.unsqueeze(0)
#         imgs = Variable(imgs)
#         if gpu:
#             imgs = imgs.cuda()
#             labels = labels.cuda()

        x = Variable(torch.FloatTensor(np.random.random((2, 1, 256, 256))))
            
        
        pred_masks = model(x)
#         loss = criterion(pred_masks, labels)
        loss = torch.sum(pred_masks)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        gc.collect()
        del x, pred_masks
        
        epoch_loss += loss
        print('Epoch {}, loss: {}'.format(epoch, epoch_loss))

In [10]:
model = unet(1, 2)
train_model(model, 1, 5, gpu=False)

Epoch 0, loss: Variable containing:
-9346.3994
[torch.FloatTensor of size 1]

Epoch 1, loss: Variable containing:
-1.5957e+13
[torch.FloatTensor of size 1]



KeyboardInterrupt: 

In [9]:
model = UNet(3, depth=5, merge_mode='concat', in_channels=1)
x = Variable(torch.FloatTensor(np.random.random((1, 1, 512, 512))))
out = model(x)
loss = torch.sum(out)
loss.backward()

In [10]:
del x, out

In [11]:
del loss

In [2]:
data_dir = os.path.join((os.getcwd()), 'data')
labels = io.imread(os.path.join(data_dir, 'train-labels.tif')) #load training labels
labels = torchvision.transforms.ToTensor()(labels)
labels.requires_grad = False
labels = labels.transpose(0,1) #needed because of the TIF files

labels = torch.Tensor.long(labels)
# labels = Variable(labels)

imgs = io.imread(os.path.join(data_dir, 'train-volume.tif')) #load training data
imgs = torchvision.transforms.ToTensor()(imgs)
imgs = imgs.transpose(0,1)
imgs.requires_grad = False
imgs = imgs.unsqueeze(1)
# imgs = Variable(imgs)

In [30]:
print(labels)

Variable containing:
( 0 ,.,.) = 
   1   1   1  ...    1   1   1
   1   1   1  ...    1   1   1
   1   1   1  ...    1   1   1
     ...       ⋱       ...    
   1   1   1  ...    1   1   1
   1   1   1  ...    1   1   1
   1   1   1  ...    1   1   1

( 1 ,.,.) = 
   1   1   1  ...    1   1   1
   1   1   1  ...    1   1   1
   1   1   1  ...    1   1   1
     ...       ⋱       ...    
   1   1   1  ...    1   1   1
   1   1   1  ...    1   1   1
   1   1   1  ...    1   1   1

( 2 ,.,.) = 
   1   1   1  ...    1   1   1
   1   1   1  ...    1   1   1
   1   1   1  ...    1   1   1
     ...       ⋱       ...    
   1   1   1  ...    1   1   1
   1   1   1  ...    1   1   1
   1   1   1  ...    1   1   1
... 

(27 ,.,.) = 
   1   1   1  ...    1   1   1
   1   1   1  ...    1   1   1
   1   1   1  ...    1   1   1
     ...       ⋱       ...    
   1   1   1  ...    1   1   1
   1   1   1  ...    1   1   1
   1   1   1  ...    1   1   1

(28 ,.,.) = 
   1   1   1  ...    1   1   1
   1  

In [12]:
import matplotlib.pyplot as plt

In [19]:
torchvision.utils.save_image(labels[25], 'label.png')

In [15]:
test = imgs[25]
test = test.unsqueeze(0)
test = Variable(test)
out = model(test)

In [16]:
_, indices = torch.max(out, 1)

In [11]:
model.zero_grad()

In [62]:
colors = np.array([[255/255,255/255,255/255], [100/255,100/255,100/255]])

In [81]:
colorsTensor = torch.Tensor(colors)

In [100]:
output = colorsTensor[indices.data.view(-1)].view(512, 512, 3)

In [98]:
output = output.permute(2, 0, 1)

In [117]:
dataset = torch.utils.data.TensorDataset(imgs.data, labels.data)

In [118]:
dataloader = torch.utils.data.DataLoader(dataset)

In [124]:
for (batch, label) in dataloader:
    print(label)
    break


( 0 ,.,.) = 
   1   1   1  ...    1   1   1
   1   1   1  ...    1   1   1
   1   1   1  ...    1   1   1
     ...       ⋱       ...    
   1   1   1  ...    1   1   1
   1   1   1  ...    1   1   1
   1   1   1  ...    1   1   1
[torch.LongTensor of size 1x512x512]



In [3]:
import train

In [4]:
from model.unet import unet

In [5]:
model = unet(1,2)

In [6]:
train.train(model, imgs, labels, batch_size=1, epochs=1)

tensor(0.6572)
tensor(0.5146)
tensor(0.4106)
tensor(0.3620)
tensor(0.3712)
Epoch 0, loss: 2.315690279006958


In [154]:
import importlib

In [172]:
importlib.reload(train)

<module 'train' from 'D:\\Machine Learning\\UNet_pytorch\\train.py'>

In [171]:
imgs


( 0 , 0 ,.,.) = 
  0.4941  0.5412  0.5529  ...   0.6078  0.6392  0.5922
  0.4196  0.4627  0.5294  ...   0.5843  0.6039  0.5529
  0.4784  0.4824  0.5686  ...   0.6588  0.6510  0.6275
           ...             ⋱             ...          
  0.6314  0.6000  0.5647  ...   0.5451  0.6196  0.7176
  0.6431  0.5843  0.5686  ...   0.4510  0.5529  0.6431
  0.6941  0.6000  0.5373  ...   0.4706  0.5804  0.6706
      ⋮  

( 1 , 0 ,.,.) = 
  0.4510  0.4510  0.4627  ...   0.4902  0.5882  0.5451
  0.5137  0.5176  0.4980  ...   0.5725  0.6627  0.5451
  0.5098  0.5137  0.4314  ...   0.5647  0.6078  0.5020
           ...             ⋱             ...          
  0.4863  0.4118  0.4510  ...   0.6627  0.6941  0.7137
  0.4902  0.4353  0.5137  ...   0.6745  0.6980  0.6706
  0.5373  0.5608  0.5961  ...   0.6824  0.6745  0.6078
      ⋮  

( 2 , 0 ,.,.) = 
  0.6118  0.6431  0.6824  ...   0.6431  0.6431  0.6941
  0.7255  0.7882  0.7686  ...   0.6392  0.6353  0.6863
  0.7882  0.8078  0.8157  ...   0.6980  0.7333

In [13]:
gc.collect()

0