In [1]:
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import time

In [2]:
# Training data: MNIST digits resized to 80x80 single-channel tensors in [0, 1]
# (ToTensor), served in shuffled mini-batches of `batch_size`.
batch_size = 100
transform = transforms.Compose([
    transforms.Resize((80, 80)),  # input size the U-Net cells below were run with
    transforms.ToTensor()])
# download=True fetches the dataset on first use and is a no-op when it is
# already under `root`, so the notebook survives Restart & Run All on a fresh
# machine instead of failing with "Dataset not found".
train_dataset = torchvision.datasets.MNIST(root='../../data',
                                           train=True,
                                           transform=transform,
                                           download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

In [3]:
def weights_init_normal(m):
    """Re-initialise a single module's parameters in place.

    Intended for ``model.apply(weights_init_normal)``, which calls it on every
    submodule:

    * convolutions (regular and transposed): weight ~ N(0, 0.02);
    * ``BatchNorm2d`` with affine parameters: weight ~ N(1, 0.02), bias = 0.

    Every other module is left untouched.

    Type checks are done with ``isinstance`` rather than by searching the
    class name for ``'Conv'``. The name search also matched wrapper modules
    such as ``UnetUpConv``, which have no ``weight`` attribute and raised
    ``AttributeError``.

    Args:
        m: the ``nn.Module`` to initialise.
    """
    conv_types = (nn.Conv1d, nn.Conv2d, nn.Conv3d,
                  nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)
    if isinstance(m, conv_types):
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d) and m.affine:
        # affine=False BatchNorm has weight/bias set to None; nothing to init.
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.constant_(m.bias, 0.0)


class UNetDown(nn.Module):
    """Encoder stage: two unpadded 3x3 convolutions with ReLU activations.

    Layer order: Conv -> ReLU -> Conv -> [InstanceNorm] -> ReLU -> [Dropout].
    Both convolutions use ``padding=0`` (each trims one pixel per border), so
    height and width each shrink by 4. The channel count goes
    ``in_size -> out_size``.

    Args:
        in_size: number of input channels.
        out_size: number of output channels.
        normalize: if truthy, insert ``InstanceNorm2d`` after the second conv.
        dropout: if non-zero, append ``Dropout`` with this probability.
    """

    def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=0, bias=False)
        stack = [
            nn.Conv2d(in_size, out_size, **conv_kwargs),
            nn.ReLU(),
            nn.Conv2d(out_size, out_size, **conv_kwargs),
        ]
        # The optional norm sits between the second conv and its activation.
        stack += [nn.InstanceNorm2d(out_size)] if normalize else []
        stack.append(nn.ReLU())
        if dropout:
            stack.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the conv stack to ``x`` (assumed N, C, H, W)."""
        return self.model(x)

# Decoder (up-path) upsampling stage, built on ConvTranspose2d.
class UNetUp(nn.Module):
    """Decoder upsampling step: a single stride-2 transposed convolution.

    ``ConvTranspose2d(kernel_size=4, stride=2, padding=1)`` exactly doubles
    the spatial size (H, W -> 2H, 2W) and maps ``in_size -> out_size``
    channels.

    Args:
        in_size: number of input channels.
        out_size: number of output channels.
        dropout: if non-zero, append ``Dropout`` with this probability. This
            matches the other stages; the argument used to be accepted but
            silently ignored. The default (0.0) keeps the original layer list.
    """

    def __init__(self, in_size, out_size, dropout=0.0):
        super(UNetUp, self).__init__()
        layers = [nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False)]
        if dropout:
            layers.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Upsample ``x`` (assumed N, C, H, W) by a factor of 2 spatially."""
        return self.model(x)


class UnetUpConv(nn.Module):
    """Decoder conv stage: two unpadded 3x3 convolutions with ReLU.

    Layer order: Conv -> ReLU -> Conv -> [InstanceNorm] -> ReLU -> [Dropout].
    With ``padding=0`` each convolution trims one pixel per border, so height
    and width each shrink by 4. The channel count goes ``in_size -> out_size``.

    Args:
        in_size: number of input channels (e.g. upsampled + skip channels).
        out_size: number of output channels.
        normalize: if truthy, insert ``InstanceNorm2d`` after the second conv.
        dropout: if non-zero, append ``Dropout`` with this probability.
    """

    def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
        super().__init__()

        def valid_conv3x3(c_in):
            # 3x3 kernel, stride 1, no padding, no bias -> out_size channels.
            return nn.Conv2d(c_in, out_size, kernel_size=3, stride=1,
                             padding=0, bias=False)

        first = valid_conv3x3(in_size)
        second = valid_conv3x3(out_size)

        tail = []
        if normalize:
            tail.append(nn.InstanceNorm2d(out_size))
        tail.append(nn.ReLU())
        if dropout:
            tail.append(nn.Dropout(dropout))

        self.model = nn.Sequential(first, nn.ReLU(), second, *tail)

    def forward(self, x):
        """Apply the conv stack to ``x`` (assumed N, C, H, W)."""
        return self.model(x)

In [4]:
def center_crop_concat(input_x, input_y):
    """Center-crop ``input_x`` to ``input_y``'s spatial size, then concatenate.

    This is the U-Net skip connection for unpadded convolutions. The encoder
    map ``input_x`` is spatially larger than the decoder map ``input_y``, so
    its central region is cut out and stacked with ``input_y`` along dim 1
    (channels).

    Args:
        input_x: tensor (N, Cx, Hx, Wx) with Hx >= Hy and Wx >= Wy.
        input_y: tensor (N, Cy, Hy, Wy).

    Returns:
        Tensor (N, Cx + Cy, Hy, Wy), with the cropped ``input_x`` channels first.

    Raises:
        ValueError: if ``input_x`` is spatially smaller than ``input_y``.
    """
    target_h, target_w = input_y.shape[2], input_y.shape[3]
    diff_h = input_x.shape[2] - target_h
    diff_w = input_x.shape[3] - target_w
    if diff_h < 0 or diff_w < 0:
        raise ValueError(
            f"input_x spatial size {tuple(input_x.shape[2:])} is smaller than "
            f"input_y spatial size {tuple(input_y.shape[2:])}")
    # Height and width are cropped independently; reusing the height offset
    # for both axes mis-crops non-square maps. For an odd difference the extra
    # row/column is dropped from the bottom/right.
    top = diff_h // 2
    left = diff_w // 2
    x_crop = input_x[:, :, top:top + target_h, left:left + target_w]
    return torch.cat((x_crop, input_y), 1)


In [18]:

class Unet_model(nn.Module):
    """Two-skip U-Net built from unpadded ("valid") 3x3 convolutions.

    Encoder: ``down1`` (in_channel -> 64), ``down2`` (64 -> 128) and ``down3``
    (128 -> 256), separated by 2x2, stride-2 max-pooling. Decoder: two
    transposed-conv upsampling steps. Each step is followed by a skip
    connection that center-crops the matching encoder map and a conv stage.
    ``final`` then produces ``out_channel`` maps.

    Because every convolution is unpadded, the output is spatially smaller
    than the input; e.g. an 80x80 input yields a 36x36 output.

    Args:
        in_channel: number of input image channels.
        out_channel: number of output channels (default 1, the previously
            hard-coded value).
    """

    def __init__(self, in_channel=1, out_channel=1):
        super(Unet_model, self).__init__()
        # Construction order is kept as-is so random initialisation and
        # state_dict keys match earlier runs and checkpoints.
        self.down1 = UNetDown(in_channel, 64)
        self.down2 = UNetDown(64, 128)
        self.down3 = UNetDown(128, 256)

        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.up1 = UNetUp(256, 128)
        self.up1_cnn = UnetUpConv(256, 128)

        self.up2 = UNetUp(128, 64)
        self.up2_cnn = UnetUpConv(128, 64)

        # NOTE(review): up3 / up3_cnn are never used in forward(). They are
        # kept only for init-order and checkpoint compatibility; remove them
        # once no saved state_dict depends on these keys.
        self.up3 = UNetUp(64, 32)
        self.up3_cnn = UnetUpConv(64, 32)

        self.final = UnetUpConv(64, out_channel)

    def forward(self, x):
        """Run the network on ``x`` (assumed N, in_channel, H, W)."""
        # Encoder (contracting path).
        o1 = self.down1(x)                       # 64 channels
        o1_max = self.max_pool(o1)
        o2 = self.down2(o1_max)                  # 128 channels
        o2_max = self.max_pool(o2)
        o3 = self.down3(o2_max)                  # 256 channels (bottleneck)

        # Decoder (expanding path) with center-cropped skip connections.
        u1 = self.up1(o3)                        # upsampled, 128 channels
        u1_concat = center_crop_concat(o2, u1)   # 128 + 128 = 256 channels
        u1_conv = self.up1_cnn(u1_concat)        # 128 channels

        u2 = self.up2(u1_conv)                   # upsampled, 64 channels
        u2_concat = center_crop_concat(o1, u2)   # 64 + 64 = 128 channels
        u2_conv = self.up2_cnn(u2_concat)        # 64 channels

        return self.final(u2_conv)               # out_channel channels
        


In [19]:
# Smoke test: push one training batch through the (untrained) U-Net and
# print the input and output shapes.
torch.manual_seed(0)  # reproducible weight init and batch selection
# Fall back to CPU so the notebook also runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = Unet_model().to(device)
model.eval()  # inference mode; the model is only evaluated here

images, _ = next(iter(train_loader))
img = images.to(device)
print(img.shape)

# Inference only: skip building the autograd graph. Call the module rather
# than .forward() so any registered hooks run.
with torch.no_grad():
    im = model(img)
print(im.shape)

torch.Size([100, 1, 80, 80])
torch.Size([100, 1, 36, 36])


In [26]:
# The network output ends in InstanceNorm + ReLU and is not bounded to [0, 1].
# save_image clamps values outside that range, so rescale by min/max first.
save_image(im, 'Unet_image.png', normalize=True)