In [1]:
import torch
import torch.nn as nn
from dataloader import dataset
from torch.utils.data import Dataset, DataLoader

In [2]:
# Seed PyTorch's global random number generator so that runs are repeatable.
SEED = 0
torch.manual_seed(SEED)

<torch._C.Generator at 0x1c2c3de78b0>

In [3]:
# Randomly partition `dataset` into training / validation / test subsets
# using a 60% / 20% / 20% split. The first two sizes are truncated to
# integers; the test subset receives whatever remains, so the three sizes
# always add up to len(dataset).
n_samples = len(dataset)
train_set_size = int(n_samples * 0.6)
valid_set_size = int(n_samples * 0.2)
test_set_size = n_samples - train_set_size - valid_set_size
split_sizes = [train_set_size, valid_set_size, test_set_size]
train_set, valid_set, test_set = torch.utils.data.random_split(dataset, split_sizes)

In [4]:
# Wrap each split in a DataLoader that yields mini-batches of `batch_size` samples.
batch_size = 5

# Training loader: reshuffle the samples at the start of every epoch so the
# optimiser does not see the batches in the same order each time. drop_last
# keeps every training batch at exactly `batch_size` samples.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True)

# Evaluation loaders: keep a fixed order and keep the final partial batch, so
# that every validation / test sample contributes to the reported metrics.
valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False, drop_last=False)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, drop_last=False)

In [5]:
# Implement UNET
# https://medium.com/analytics-vidhya/unet-implementation-in-pytorch-idiot-developer-da40d955f201
class conv_block(nn.Module):
    """Two consecutive 3x3 convolutions, each followed by a ReLU.

    Both convolutions use ``kernel_size=3`` with ``padding=1`` (and the
    default stride of 1), so the spatial size of the input is preserved and
    only the channel count changes from ``in_c`` to ``out_c``.

    Args:
        in_c: number of channels of the input feature map.
        out_c: number of channels produced by both convolutions.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        # No normalisation layer is applied between convolution and activation
        # (batch norm was disabled in this notebook).
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, padding=1)
        # The activation class is spelled ``nn.ReLU``; ``nn.RELU`` does not exist.
        self.relu = nn.ReLU()

    def forward(self, inputs):
        """Apply conv -> ReLU -> conv -> ReLU and return the resulting tensor."""
        x = self.relu(self.conv1(inputs))
        x = self.relu(self.conv2(x))
        return x


In [6]:
# Encoder Block
class encoder_block(nn.Module):
    """One encoder stage: a ``conv_block`` followed by 2x2 max pooling.

    Args:
        in_c: number of channels of the input feature map.
        out_c: number of channels produced by the convolution block.

    ``forward`` returns a pair ``(x, p)``:
        x: output of the convolution block (the input of the pooling layer).
        p: output of the pooling layer.
    """

    def __init__(self, in_c, out_c):
        # ``super`` must be called to obtain the proxy object; ``super.__init__()``
        # raises TypeError and leaves the module uninitialised.
        super().__init__()
        self.conv = conv_block(in_c, out_c)
        self.pool = nn.MaxPool2d((2, 2))

    def forward(self, inputs):
        """Return ``(x, p)``: features before pooling and after pooling."""
        x = self.conv(inputs)
        p = self.pool(x)
        return x, p

In [7]:
# Decoder Block
class decoder_block(nn.Module):
    """One decoder stage: transposed-conv upsampling, skip concat, conv block.

    Args:
        in_c: number of channels of the incoming feature map.
        out_c: number of channels after upsampling and after the final
            convolution block.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        # kernel_size=2 with stride=2 and no padding.
        self.up = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2, padding=0)
        # The convolution block is built for 2 * out_c input channels:
        # the upsampled map concatenated with the skip connection.
        self.conv = conv_block(2 * out_c, out_c)

    def forward(self, inputs, skip):
        """Upsample ``inputs``, concatenate ``skip`` on the channel axis, convolve."""
        upsampled = self.up(inputs)
        merged = torch.cat((upsampled, skip), dim=1)
        return self.conv(merged)

In [1]:
# UNET Architecture
class unet(nn.Module):
    """U-Net: four encoder stages, a bottleneck, four decoder stages, 1x1 head.

    The encoder widths are 64, 128, 256 and 512 channels, the bottleneck has
    1024 channels, and the decoder mirrors the encoder back down to 64
    channels. Each decoder stage receives the pre-pooling output of the
    matching encoder stage as its skip connection. The final layer is a 1x1
    convolution with no activation, so the network returns raw scores.

    Args:
        in_channels: number of channels of the input image (default 3).
        out_channels: number of channels of the output map (default 1).

    Note:
        The input passes through four encoder stages that each downsample and
        four decoder stages that each upsample; the skip concatenation
        requires matching spatial sizes, so height and width are presumably
        expected to be multiples of 16 -- confirm against the data.
    """

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()

        # Encoder
        self.e1 = encoder_block(in_channels, 64)
        self.e2 = encoder_block(64, 128)
        self.e3 = encoder_block(128, 256)
        self.e4 = encoder_block(256, 512)

        # Bottleneck
        self.b = conv_block(512, 1024)

        # Decoder
        self.d1 = decoder_block(1024, 512)
        self.d2 = decoder_block(512, 256)
        self.d3 = decoder_block(256, 128)
        self.d4 = decoder_block(128, 64)

        # Classifier: 1x1 convolution mapping the 64 decoder channels to the output map.
        self.outputs = nn.Conv2d(64, out_channels, kernel_size=1, padding=0)

    def forward(self, inputs):
        """Run the full encoder -> bottleneck -> decoder -> classifier path."""
        # Encoder: each stage must use its own block; s* are the skip tensors,
        # p* the pooled tensors fed to the next stage.
        s1, p1 = self.e1(inputs)
        s2, p2 = self.e2(p1)
        s3, p3 = self.e3(p2)
        s4, p4 = self.e4(p3)

        # Bottleneck
        b = self.b(p4)

        # Decoder: skips are consumed in reverse order (deepest first).
        d1 = self.d1(b, s4)
        d2 = self.d2(d1, s3)
        d3 = self.d3(d2, s2)
        d4 = self.d4(d3, s1)

        # Classifier is applied to the last decoder output.
        return self.outputs(d4)

NameError: name 'nn' is not defined