In [1]:
import os
import numpy as np
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import matplotlib.pyplot as plt
import torch.nn.UpsamplingNearest2d as UpsamplingNearest2d



class UNET(nn.Module):
    def __init__(self):
        super(UNET, self).__init__()
        
        #conv block 1
        self.conv1_1 = nn.Conv2d(3, 32, 3, padding=1, stride=1, dilation=1)
        self.conv1_2 = nn.Conv2d(3, 32, 3, padding=1, stride=1, dilation=1)
        self.conv1_3 = nn.Conv2d(3, 32, 3, padding=1, stride=1, dilation=1)
        
        #conv block 2
        self.conv2_1 = nn.Conv2d(32, 64, 3, padding=1, stride=1, dilation=1)
        self.conv2_2 = nn.Conv2d(32, 64, 3, padding=1, stride=1, dilation=1)
        self.conv2_3 = nn.Conv2d(32, 64, 3, padding=1, stride=1, dilation=1)
        
        #conv block 3
        self.conv3_1 = nn.Conv2d(64, 128, 3, padding=1, stride=1, dilation=1)
        self.conv3_2 = nn.Conv2d(64, 128, 3, padding=1, stride=1, dilation=1)
        self.conv3_3 = nn.Conv2d(64, 128, 3, padding=1, stride=1, dilation=1)        
        
        #conv block 4
        self.conv4_1 = nn.Conv2d(128, 256, 3, padding=1, stride=1, dilation=1)
        self.conv4_2 = nn.Conv2d(128, 256, 3, padding=1, stride=1, dilation=1)
        self.conv4_3 = nn.Conv2d(128, 256, 3, padding=1, stride=1, dilation=1)
        
        #reverse conv block 1
        self.deconv1_1 = nn.Conv2d(256, 128, 3, padding=1, stride=1, dilation=1)
        self.deconv1_2 = nn.Conv2d(256, 128, 3, padding=1, stride=1, dilation=1)
        self.deconv1_3 = nn.Conv2d(256, 128, 3, padding=1, stride=1, dilation=1)
        
        #reverse conv block 2
        self.deconv2_1 = nn.Conv2d(128, 64, 3, padding=1, stride=1, dilation=1)
        self.deconv2_2 = nn.Conv2d(128, 64, 3, padding=1, stride=1, dilation=1)
        self.deconv2_3 = nn.Conv2d(128, 64, 3, padding=1, stride=1, dilation=1)
        
        #reverse conv block 3
        self.deconv3_1 = nn.Conv2d(64, 32, 3, padding=1, stride=1, dilation=1)
        self.deconv3_2 = nn.Conv2d(64, 32, 3, padding=1, stride=1, dilation=1)
        self.deconv3_3 = nn.Conv2d(64, 32, 3, padding=1, stride=1, dilation=1)
        
        #reverse conv block 4
        self.deconv4_1 = nn.Conv2d(32, 3, 3, padding=1, stride=1, dilation=1)
        self.deconv4_2 = nn.Conv2d(32, 3, 3, padding=1, stride=1, dilation=1)
        self.deconv4_3 = nn.Conv2d(32, 3, 3, padding=1, stride=1, dilation=1)
        
        self.upsample = UpsamplingNearest2d(scale_factor=2)
        
    def forward(self, x):
        # init conv
        x = F.relu(self.conv1_1(x))
        x = F.relu(self.conv1_2(x), 2)
        conv1_feature = F.relu(self.conv1_3(x))
        x = F.relu(F.max_pool2d(self.conv1_3(x), 2))
        
        x = F.relu(self.conv2_1(x))
        x = F.relu(self.conv2_2(x), 2)
        conv2_feature = F.relu(self.conv2_3(x))
        x = F.relu(F.max_pool2d(self.conv2_3(x), 2))
        
        x = F.relu(self.conv3_1(x))
        x = F.relu(self.conv3_2(x), 2)
        conv3_feature = F.relu(self.conv3_3(x))
        x = F.relu(F.max_pool2d(self.conv3_3(x), 2))
        
        x = F.relu(self.conv4_1(x))
        x = F.relu(self.conv4_2(x), 2)
        x = F.relu(F.max_pool2d(self.conv4_3(x), 2))
        
        
        # init deconv
        # deconv change image size => no
        # concate conv feature
        
        x = F.relu(self.deconv1_1(x))
        x = F.relu(self.deconv1_2(x), 2)
        x = F.relu(F.biltperlate(self.deconv1_3(x), 2))
        
        x = F.relu(self.deconv2_1(x))
        x = F.relu(self.deconv2_2(x), 2)
        x = F.relu(F.biltperlate(self.deconv2_3(x), 2))
        
        x = F.relu(self.deconv3_1(x))
        x = F.relu(self.deconv3_2(x), 2)
        x = F.relu(F.biltperlate(self.deconv3_3(x), 2))
        
        x = F.relu(self.deconv4_1(x))
        x = F.relu(self.deconv4_2(x), 2)
        x = F.relu(F.biltperlate(self.deconv4_3(x), 2))
        

        

SyntaxError: invalid syntax (<ipython-input-1-9b4dff25645d>, line 15)