In [1]:
import os
import numpy as np
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from lungDataLoader import lung424

class UNET(nn.Module):

        def __init__(self):
            super(UNET, self).__init__()
            #conv block 1
            self.conv1_1 = nn.Conv2d(3, 32, 3, padding=1, stride=1, dilation=1)
            self.conv1_2 = nn.Conv2d(32, 32, 3, padding=1, stride=1, dilation=1)
            #conv block 2
            self.conv2_1 = nn.Conv2d(32, 64, 3, padding=1, stride=1, dilation=1)
            self.conv2_2 = nn.Conv2d(64, 64, 3, padding=1, stride=1, dilation=1)
            #conv block 3
            self.conv3_1 = nn.Conv2d(64, 128, 3, padding=1, stride=1, dilation=1)
            self.conv3_2 = nn.Conv2d(128, 128, 3, padding=1, stride=1, dilation=1)
            #conv block 4
            self.conv4_1 = nn.Conv2d(128, 256, 3, padding=1, stride=1, dilation=1)
            self.conv4_2 = nn.Conv2d(256, 256, 3, padding=1, stride=1, dilation=1)
            #deepeast conv block
            self.conv5_1 = nn.Conv2d(256, 512, 3, padding=1, stride=1, dilation=1)
            self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1, stride=1, dilation=1)
            #reserver conv block0
            self.deconv5_1 = nn.Conv2d(512, 256, 1, padding=0, stride=1, dilation=1)
            self.deconv5_2 = nn.Conv2d(512, 256, 3, padding=1, stride=1, dilation=1)
            self.deconv5_3 = nn.Conv2d(256, 256, 3, padding=1, stride=1, dilation=1)
            #reverse conv block 1
            self.deconv4_1 = nn.Conv2d(256, 128, 1, padding=0, stride=1, dilation=1)
            self.deconv4_2 = nn.Conv2d(256, 128, 3, padding=1, stride=1, dilation=1)
            self.deconv4_3 = nn.Conv2d(128, 128, 3, padding=1, stride=1, dilation=1)
            #reverse conv block 2
            self.deconv3_1 = nn.Conv2d(128, 64, 1, padding=0, stride=1, dilation=1)
            self.deconv3_2 = nn.Conv2d(128, 64, 3, padding=1, stride=1, dilation=1)
            self.deconv3_3 = nn.Conv2d(64, 64, 3, padding=1, stride=1, dilation=1)
            #reverse conv block 3
            self.deconv2_1 = nn.Conv2d(64, 32, 1, padding=0, stride=1, dilation=1)
            self.deconv2_2 = nn.Conv2d(64, 32, 3, padding=1, stride=1, dilation=1)
            self.deconv2_3 = nn.Conv2d(32, 32, 3, padding=1, stride=1, dilation=1)
            self.deconv1 = nn.Conv2d(32, 1, 1, padding=0, stride=1, dilation=1)
            self.upsample = nn.UpsamplingBilinear2d(scale_factor=2)

        def forward(self, x):
            x1 = F.relu(self.conv1_2(F.relu(self.conv1_1(x)))) # 224x224, 32
            x2 = F.relu(self.conv2_2(F.relu(self.conv2_1(F.max_pool2d(x1, 2))))) # 112x112, 64
            x3 = F.relu(self.conv3_2(F.relu(self.conv3_1(F.max_pool2d(x2, 2))))) # 56x56, 128
            x4 = F.relu(self.conv4_2(F.relu(self.conv4_1(F.max_pool2d(x3, 2))))) # 28x28, 256
            x = F.relu(self.conv5_2(F.relu(self.conv5_1(F.max_pool2d(x4, 2))))) # 14x14, 512
            # 28x28, 256            
            x = F.relu(self.deconv5_3(F.relu(self.deconv5_2(torch.cat([x4, self.upsample(F.relu(self.deconv5_1(x)))], dim=1))))) 
            # 56x56, 128            
            x = F.relu(self.deconv4_3(F.relu(self.deconv4_2(torch.cat([x3, self.upsample(F.relu(self.deconv4_1(x)))], dim=1))))) 
            # 112x112, 64
            x = F.relu(self.deconv3_3(F.relu(self.deconv3_2(torch.cat([x2, self.upsample(F.relu(self.deconv3_1(x)))], dim=1))))) 
            # 224x224, 32
            x = F.relu(self.deconv2_3(F.relu(self.deconv2_2(torch.cat([x1, self.upsample(F.relu(self.deconv2_1(x)))], dim=1))))) 
            x = self.deconv1(x) # 224x224, 1
            return x

In [2]:
# test the network
unet = UNET()
criterion = nn.MSELoss()
loss = 0

lungDataLoader = DataLoader(lung424, shuffle=True, batch_size=5)
dataiter = iter(lungDataLoader)
img, target = dataiter.next()
img, target = Variable(img), Variable(target)

out = unet(img)
# out = torch.squeeze(out, dim=1)
# target = torch.squeeze(out, dim=1)

print('Input tensor shape is {}'.format(img.size()))
print('Ouput tensor shape is {}'.format(out.size()))
print('Target tensor shape is {}'.format(target.size()))

loss = criterion(out,target)

print('The loss is {}'.format(loss))


Input tensor shape is torch.Size([5, 3, 224, 224])
Ouput tensor shape is torch.Size([5, 1, 224, 224])
Target tensor shape is torch.Size([5, 1, 224, 224])
The loss is Variable containing:
 0.1908
[torch.FloatTensor of size 1]



In [4]:
unet = UNET()
lungDataLoader = DataLoader(lung424, shuffle=True, batch_size=5)

lr = 0.001
momentum = 0.9
optimizer = optim.SGD(unet.parameters(), lr, momentum)
criterion = nn.MSELoss()

In [None]:
runing_loss = 0.0
for epoch in range(25):
    for i, data in enumerate(lungDataLoader):
        img, target = data
        print('Number {} batch with batch size {}'.format(i, img.size()))
        
        img, target = Variable(img), Variable(target)
        
        optimizer.zero_grad()
        out = unet(img)
        
        loss = criterion(out, target)
        loss.backward()
        
        optimizer.step()
        
        runing_loss += loss.data[0]
        print('[%d, %5d] loss: %.3f' % 
            (epoch + 1, i + 1, runing_loss))
        runing_loss = 0.0
        
#         if i % 200 == 199:
#             print(out)
#             print(target)
#             print('[%d, %5d] loss: %.3f' % 
#                 (epoch + 1, i + 1, runing_loss/200))
#             runing_loss = 0.0
    

Number 0 batch with batch size torch.Size([5, 3, 224, 224])
[1,     1] loss: 0.199
Number 1 batch with batch size torch.Size([5, 3, 224, 224])
[1,     2] loss: 0.187
Number 2 batch with batch size torch.Size([5, 3, 224, 224])
[1,     3] loss: 0.268
Number 3 batch with batch size torch.Size([5, 3, 224, 224])
[1,     4] loss: 0.171
Number 4 batch with batch size torch.Size([5, 3, 224, 224])
[1,     5] loss: 0.226
Number 5 batch with batch size torch.Size([5, 3, 224, 224])
[1,     6] loss: 0.162
Number 6 batch with batch size torch.Size([5, 3, 224, 224])
[1,     7] loss: 0.227
Number 7 batch with batch size torch.Size([5, 3, 224, 224])
[1,     8] loss: 0.148
Number 8 batch with batch size torch.Size([5, 3, 224, 224])
[1,     9] loss: 0.142
Number 9 batch with batch size torch.Size([5, 3, 224, 224])
[1,    10] loss: 0.179
Number 10 batch with batch size torch.Size([5, 3, 224, 224])
[1,    11] loss: 0.204
Number 11 batch with batch size torch.Size([5, 3, 224, 224])
[1,    12] loss: 0.252
Nu