In [8]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchinfo import summary
from torchviz import make_dot

In [4]:
class unet(nn.Module):  # subclasses nn.Module
    """U-Net-style encoder/decoder for dense (per-pixel) prediction.

    Encoder: four stages of two Conv-BN-ReLU blocks (64 -> 128 -> 256 -> 512
    channels), each followed by 2x2 max pooling, then a single 1024-channel
    Conv-BN-ReLU bottleneck block.
    Decoder: four stages that upsample with a 2x2 / stride-2 transposed conv,
    concatenate the result with the encoder feature map of the same stage
    (skip connection, channel axis), and apply two Conv-BN-ReLU blocks.
    A final 1x1 conv maps the 64 decoder channels to ``out_channels`` raw
    (un-normalized) per-pixel scores; no activation is applied.

    All 3x3 convs use padding 1, so spatial size only changes at the pooling /
    upsampling steps.

    Args:
        in_channels: number of channels of the input image (default 1).
        out_channels: number of output channels, e.g. number of classes
            (default 2).

    Input shape ``(N, in_channels, H, W)``; output shape
    ``(N, out_channels, H, W)``. H and W do not have to be multiples of 16:
    when pooling floors an odd size, the upsampled map is zero-padded to the
    size of its skip connection before concatenation.
    """

    def __init__(self, in_channels=1, out_channels=2):
        super().__init__()

        def cv2d(input_channels, output_channels, kernel_size = 3, stride = 1, padding = 1, bias = True):
            """Build a Conv2d -> BatchNorm2d -> ReLU block (indices 0, 1, 2)."""
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels,
                          kernel_size = kernel_size, stride = stride, padding = padding, bias = bias),
                nn.BatchNorm2d(num_features = output_channels),
                nn.ReLU(),
            )

        # contracting path (attribute names/order kept so state_dict keys are unchanged)

        self.enc1_1 = cv2d(in_channels, 64)
        self.enc1_2 = cv2d(64, 64)

        self.mxpool1 = nn.MaxPool2d(kernel_size = 2)

        self.enc2_1 = cv2d(64, 128)
        self.enc2_2 = cv2d(128, 128)

        self.mxpool2 = nn.MaxPool2d(kernel_size = 2)

        self.enc3_1 = cv2d(128, 256)
        self.enc3_2 = cv2d(256, 256)

        self.mxpool3 = nn.MaxPool2d(kernel_size = 2)

        self.enc4_1 = cv2d(256, 512)
        self.enc4_2 = cv2d(512, 512)

        self.mxpool4 = nn.MaxPool2d(kernel_size = 2)

        self.enc5_1 = cv2d(512, 1024)

        # expanding path

        self.dec5_1 = cv2d(1024, 512)

        self.upconv4 = nn.ConvTranspose2d(512, 512,
                                          kernel_size = 2, stride = 2, padding = 0, bias = True)

        # 2 * C inputs: upsampled map concatenated with the skip connection
        self.dec4_2 = cv2d(2 * 512, 512)
        self.dec4_1 = cv2d(512, 256)

        self.upconv3 = nn.ConvTranspose2d(256, 256,
                                          kernel_size = 2, stride = 2, padding = 0, bias = True)

        self.dec3_2 = cv2d(2 * 256, 256)
        self.dec3_1 = cv2d(256, 128)

        self.upconv2 = nn.ConvTranspose2d(128, 128,
                                          kernel_size = 2, stride = 2, padding = 0, bias = True)

        self.dec2_2 = cv2d(2 * 128, 128)
        self.dec2_1 = cv2d(128, 64)

        self.upconv1 = nn.ConvTranspose2d(64, 64,
                                          kernel_size = 2, stride = 2, padding = 0, bias = True)

        self.dec1_2 = cv2d(2 * 64, 64)
        self.dec1_1 = cv2d(64, 64)

        # 1x1 conv producing the per-pixel output channels
        self.fc = nn.Conv2d(in_channels = 64, out_channels = out_channels, kernel_size = 1, stride = 1, padding = 0, bias = True)

    @staticmethod
    def _match_size(upsampled, skip):
        """Pad (or crop) ``upsampled`` so its H and W equal those of ``skip``.

        MaxPool2d floors odd sizes, so after a stride-2 upsample the decoder
        map can be one pixel smaller than the skip connection; without this,
        ``torch.cat`` along the channel axis fails.
        """
        dy = skip.size(2) - upsampled.size(2)
        dx = skip.size(3) - upsampled.size(3)
        if dy == 0 and dx == 0:
            return upsampled
        # pad order is (left, right, top, bottom); negative amounts crop
        return nn.functional.pad(upsampled, [dx // 2, dx - dx // 2, dy // 2, dy - dy // 2])

    def forward(self, x):
        """Run the network on ``x`` of shape (N, in_channels, H, W).

        Returns a tensor of shape (N, out_channels, H, W).
        """
        # contracting path
        enc1_1 = self.enc1_1(x)
        enc1_2 = self.enc1_2(enc1_1)
        enc_pool1 = self.mxpool1(enc1_2)

        enc2_1 = self.enc2_1(enc_pool1)
        enc2_2 = self.enc2_2(enc2_1)
        enc_pool2 = self.mxpool2(enc2_2)

        enc3_1 = self.enc3_1(enc_pool2)
        enc3_2 = self.enc3_2(enc3_1)
        enc_pool3 = self.mxpool3(enc3_2)

        enc4_1 = self.enc4_1(enc_pool3)
        enc4_2 = self.enc4_2(enc4_1)
        enc_pool4 = self.mxpool4(enc4_2)

        enc5_1 = self.enc5_1(enc_pool4)

        # expanding path
        dec5_1 = self.dec5_1(enc5_1)

        # dim = 0: batch, 1: channel, 2: y (height), 3: x (width) -> concatenate along channels
        dec_pool4 = self._match_size(self.upconv4(dec5_1), enc4_2)
        cat4 = torch.cat((dec_pool4, enc4_2), dim = 1)
        dec4_2 = self.dec4_2(cat4)
        dec4_1 = self.dec4_1(dec4_2)

        dec_pool3 = self._match_size(self.upconv3(dec4_1), enc3_2)
        cat3 = torch.cat((dec_pool3, enc3_2), dim = 1)
        dec3_2 = self.dec3_2(cat3)
        dec3_1 = self.dec3_1(dec3_2)

        dec_pool2 = self._match_size(self.upconv2(dec3_1), enc2_2)
        cat2 = torch.cat((dec_pool2, enc2_2), dim = 1)
        dec2_2 = self.dec2_2(cat2)
        dec2_1 = self.dec2_1(dec2_2)

        dec_pool1 = self._match_size(self.upconv1(dec2_1), enc1_2)
        cat1 = torch.cat((dec_pool1, enc1_2), dim = 1)
        dec1_2 = self.dec1_2(cat1)
        dec1_1 = self.dec1_1(dec1_2)

        x = self.fc(dec1_1)

        return x

In [5]:
# Build the segmentation network with its default constructor arguments.
model: nn.Module = unet()

In [None]:
# Training hyper-parameters: learning rate, samples per mini-batch, number of epochs.
lr, batch_size, num_epoch = 1e-2, 32, 100

In [None]:
class dataset(torch.utils.data.Dataset):
    """Dataset wrapper that remembers a data directory and an optional transform.

    NOTE(review): only ``__init__`` is visible here; a map-style Dataset also
    needs ``__len__`` and ``__getitem__`` before it can be used with DataLoader.
    """

    def __init__(self, data_dir, transform = None):
        # keep the optional per-sample transform (None means no transform)
        self.transform = transform
        # root directory the samples are expected to be read from
        self.data_dir = data_dir
    

