In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch import optim
import torchvision
from torch.utils.data import DataLoader
from skimage import io

In [3]:
from skimage import io
import os

In [4]:
class double_conv_relu(nn.Module):
    
    def __init__(self, in_channels, out_channels, dropout=False):
        super(double_conv_relu, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.drop = nn.Dropout2d(p=0.2)
        self.norm = nn.BatchNorm2d(out_channels)
        self.ReLU = nn.ReLU(inplace=True)
        self.dropout = dropout
    def forward(self, x):
        out = self.conv1(x)
        out = self.norm(out)
        out = self.ReLU(out)
        out = self.conv2(out)
        out = self.norm(out)
        out = self.ReLU(out)
        if(self.dropout):
            out = self.drop(out)
        return out
    


class upsample(nn.Module):
    def __init__(self, in_channels, out_channels, bilinear=False):
        super(upsample, self).__init__()
        
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear')
        else:
            self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        
    def forward(self, x):
        out = self.up(x)
        return out
    
class concatenate_conv(nn.Module):
    def __init__(self, layer_size):
        super(concatenate_conv, self).__init__()
        self.conv = double_conv_relu(layer_size*2, layer_size)
        
    def forward(self, encoder_layer, decoder_layer):
        out = torch.cat([encoder_layer, decoder_layer], dim=1)
        out = self.conv(out)
        return out
        

In [5]:
class unet(nn.Module):
    def __init__(self, in_channels, out_classes, dropout=False):
        super(unet, self).__init__()
        
        self.encoder_conv1 = double_conv_relu(in_channels, 64, dropout)
        self.encoder_conv2 = double_conv_relu(64, 128, dropout)
        self.encoder_conv3 = double_conv_relu(128, 256, dropout)
        self.encoder_conv4 = double_conv_relu(256, 512, dropout)
        self.encoder_conv5 = double_conv_relu(512, 512, dropout) #set out channels to 512 instead of 1024 for memory
        
        self.decoder_conv1 = concatenate_conv(512)
        self.decoder_conv2 = concatenate_conv(256)
        self.decoder_conv3 = concatenate_conv(128)
        self.decoder_conv4 = concatenate_conv(64)
        
        self.up1 = upsample(512, 512)
        self.up2 = upsample(512, 256)
        self.up3 = upsample(256, 128)
        self.up4 = upsample(128, 64)
        
        self.down = nn.MaxPool2d(kernel_size=2, stride=2)
        
        self.output_conv = nn.Conv2d(64, out_classes, kernel_size=1)
        
    def forward(self, x):
        encode1 = self.encoder_conv1(x)
        out = self.down(encode1)
        encode2 = self.encoder_conv2(out)
        out = self.down(encode2)
        encode3 = self.encoder_conv3(out)
        out = self.down(encode3)
        encode4 = self.encoder_conv4(out)
        out = self.down(encode4)
        encode5 = self.encoder_conv5(out)
        decode = self.up1(encode5)
        decode = self.decoder_conv1(encode4, decode)
        decode = self.up2(decode)
        decode = self.decoder_conv2(encode3, decode)
        decode = self.up3(decode)
        decode = self.decoder_conv3(encode2, decode)
        decode = self.up4(decode)
        decode = self.decoder_conv4(encode1, decode)
        out = self.output_conv(decode)
        
        return out
        
        

In [6]:
model = unet(1,2)
sum(p.numel() for p in model.parameters() if p.requires_grad)

20548738

In [7]:
from torchvision.transforms import ToTensor

def train_model(model, batch_size, epochs, lr=0.1, gpu=False):
    
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    criterion = nn.CrossEntropyLoss()
    
    for epoch in range(epochs):
        
        data_dir = os.path.join((os.getcwd()), 'data')
        labels = io.imread(os.path.join(data_dir, 'train-labels.tif')) #load training labels
        labels = ToTensor()(labels)
        labels.requires_grad = False
        labels = labels.transpose(0,1) #needed because of the TIF files
        
#         labels = labels.unsqueeze(1)
        labels = labels[0]
        labels = labels.unsqueeze(0)
        labels = torch.Tensor.long(labels)
        labels = Variable(labels)
        
        imgs = io.imread(os.path.join(data_dir, 'train-volume.tif')) #load training data
        imgs = ToTensor()(imgs)
        imgs = imgs.transpose(0,1)
        imgs.requires_grad = False
        imgs = imgs.unsqueeze(1)
        imgs = imgs[0]
        imgs = imgs.unsqueeze(0)
        imgs = Variable(imgs)
        if gpu:
            imgs = imgs.cuda()
            labels = labels.cuda()
            
        
        pred_masks = model(imgs)
        loss = criterion(pred_masks, labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        epoch_loss = loss.item()
        print('Epoch {}, loss: {}'.format(epoch, epoch_loss))

In [8]:
model = unet(1, 2)
train_model(model, 1, 1, gpu=False)



AttributeError: 'Variable' object has no attribute 'item'