In [1]:
import os
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms

from PIL import Image

os.listdir('../test_r_unet/data/')

['test', 'labels', 'images']

In [2]:
# arguments
timesteps = 3
batch_size = 1
num_epoch = 1
input_size = 128

gru_nan = False

transform = transforms.Compose([
                              transforms.Resize((input_size, input_size)),
                              transforms.ToTensor(),
                              ])

# decive
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
# way to the data folders
folder_data = "../test_r_unet/data/images"
folder_mask = "../test_r_unet/data/labels"
folder_test = "../test_r_unet/data/test"

file_names = os.listdir('../test_r_unet/data/images')

In [4]:
class MedData(Dataset):
    def __init__(self):
        super().__init__()
        self.time = timesteps
        self.folder_data = folder_data
        self.folder_mask = folder_mask
        self.file_names = file_names
        
    def __getitem__(self, idx):
        gif_list = []
        for i in range(self.time):
            gif_list += transform(Image.open(self.folder_data + '/' + file_names[idx+i])).unsqueeze(0)
        gif_data = torch.stack(gif_list)
        gif_list.clear()
        for i in range(self.time):
            gif_list += transform(Image.open(self.folder_mask + '/' + file_names[idx+i])).unsqueeze(0)
        gif_mask = torch.stack(gif_list)
        gif_list.clear()
        return gif_data, gif_mask
    
    def __len__(self):
        return len(self.file_names) - self.time + 1

In [5]:
dataset = MedData()
train_loader = DataLoader(dataset=dataset,
                          batch_size=batch_size,
                          num_workers=2,
                          shuffle=True)

In [6]:
class GruCell(nn.Module):
    
    def __init__(self, channel):
        super(GruCell, self).__init__()
        self.conv_relu = nn.Sequential(nn.Conv2d(in_channels=channel, out_channels=channel, kernel_size=3, stride=1, padding=1),
                                       nn.ELU(),
                                       nn.Dropout(p=0.2))
        
        self.conv_relu_2x = nn.Sequential(nn.Conv2d(in_channels=channel+channel, out_channels=channel, kernel_size=3, stride=1, padding=1),
                                          nn.ELU(),
                                          nn.Dropout(p=0.2))
        
        
        self.sig = nn.Sigmoid()

     
    def forward(self, x, hidden):
        input = torch.cat([x, hidden],dim=1)

        update_gate = self.conv_relu_2x(input)
        update_gate = self.sig((update_gate)) ### output after update gate
        
        reset_gate = self.conv_relu_2x(input)
        reset_gate = self.sig((reset_gate)) ### output after reset gate
        
        memory_gate_for_input = self.conv_relu(x)
        memory_gate_for_hidden = self.conv_relu(hidden)

        memory_content = self.sig((memory_gate_for_input + (reset_gate * memory_gate_for_hidden))) ### output for reset gate(affects how the reset gate do work)
        
        hidden = (update_gate * hidden) + ((1 - update_gate) * memory_content) # torch.ones(input_size, hidden_size)

        return hidden, hidden

In [11]:
# create class Gru
class Gru(nn.Module):

    def __init__(self, channels_size):
        super(Gru, self).__init__()
        self.batch_size = batch_size
        self.timesteps = timesteps
        self.channels_size = channels_size
        self.input_size = input_size
        self.hidden_size = (self.batch_size, channels_size, input_size, input_size)
        
        self.gru_layer0 = GruCell(channels_size)
        self.init_hidden = torch.zeros(self.hidden_size).to(device)
        self.gru_nan = gru_nan


    def forward(self, x):
        x_cells = None
        x_list = []
        if self.gru_nan == False:
            try:
                x = x.reshape(batch_size, timesteps, self.channels_size, self.input_size, self.input_size)
                x = x.permute(1, 0, 2, 3, 4)
                for i in range(timesteps):
                    if x_cells is None:
                        x_cells, hidden = self.gru_layer0(x[i], self.init_hidden)
                        x_list.append(x_cells)
                    else:
                        x_i, hidden = self.gru_layer0(x[i], hidden)
                        x_list.append(x_i)
                x_cells = torch.stack(x_list)

             ##### FOR LAST BATCH
            except RuntimeError:
                x = x.reshape(1, timesteps, self.channels_size, self.input_size, self.input_size) #last batch is (15), but batch_size = 16, #arg.timesteps = 2 
                x = x.permute(1, 0, 2, 3, 4)
                hidden_zero = torch.zeros_like(x)
                for i in range(timesteps):
                    if x_cells is None:
                        x_cells, hidden = self.gru_layer0(x[i], hidden_zero[0])
                        x_list.append(x_cells)
                    else:
                        x_i, hidden = self.gru_layer0(x[i], hidden)
                        x_list.append(x_i)
                x_cells = torch.stack(x_list)
             #####
        elif self.gru_nan == True:
            try:
                x = x.reshape(batch_size, timesteps, self.channels_size, self.input_size, self.input_size)
                x = x.permute(1, 0, 2, 3, 4)
                for i in range(timesteps):
                    if x_cells is None:
                        x_cells, hidden = self.gru_layer0(x[i], self.init_hidden)
                        x_list.append(x_cells)
                    else:
                        x_i, hidden = self.gru_layer0(x[i], self.init_hidden)
                        x_list.append(x_i)
                x_cells = torch.stack((x_cells, x_i))
            ##### FOR LAST BATCH
            except RuntimeError:
                x = x.reshape(1, timesteps, self.channels_size, self.input_size, self.input_size) #last batch is (15), but batch_size = 16, #arg.timesteps = 2 
                x = x.permute(1, 0, 2, 3, 4)
                hidden_zero = torch.zeros_like(x)
                for i in range(timesteps):
                    if x_cells is None:
                        x_cells, hidden = self.gru_layer0(x[i], hidden_zero[0])
                        x_list.append(x_cells)
                    else:
                        x_i, hidden = self.gru_layer0(x[i], hidden)
                        x_list.append(x_i)
                x_cells = torch.stack((x_cells, x_i))
        else:
            print('gru_nan can be only True or False')
            quit()

        return x_cells  

In [12]:
class Conv3x3Small(nn.Module):
    def __init__(self, in_feat, out_feat):
        super(Conv3x3Small, self).__init__()

        self.conv1 = nn.Sequential(nn.Conv2d(in_feat, out_feat,
                                             kernel_size=3,
                                             stride=1,
                                             padding=1),
                                   nn.ELU(),
                                   nn.Dropout(p=0.2))

        self.conv2 = nn.Sequential(nn.Conv2d(out_feat, out_feat,
                                             kernel_size=3,
                                             stride=1,
                                             padding=1),
                                   nn.ELU())

    def forward(self, inputs):
        outputs = self.conv1(inputs)
        outputs = self.conv2(outputs)
        return outputs


class DoubleConv(nn.Module):

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ELU(),
            nn.Dropout(p=0.2),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ELU(),
            nn.Dropout(p=0.2)
        )

    def forward(self, x):
        return self.double_conv(x)

    
class Conv(nn.Module):

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ELU(),
            nn.Dropout(p=0.2)
        )

    def forward(self, x):
        return self.conv(x)


class UpConcat(nn.Module):
    def __init__(self, in_feat, out_feat):
        super(UpConcat, self).__init__()

        self.up = nn.UpsamplingBilinear2d(scale_factor=2)

        # self.deconv = nn.ConvTranspose2d(in_feat, out_feat,
        #                                  kernel_size=3,
        #                                  stride=1,
        #                                  dilation=1)

        self.deconv = nn.ConvTranspose2d(in_feat,
                                         out_feat,
                                         kernel_size=2,
                                         stride=2)

    def forward(self, inputs, down_outputs):
        # TODO: Upsampling required after deconv?
        # outputs = self.up(inputs)
        outputs = self.deconv(inputs)
        out = torch.cat([down_outputs, outputs], 1)
        return out


class UpSample(nn.Module):
    def __init__(self, in_feat, out_feat):
        super(UpSample, self).__init__()

        self.up = nn.Upsample(scale_factor=2, mode='nearest')

        self.deconv = nn.ConvTranspose2d(in_feat,
                                         out_feat,
                                         kernel_size=2,
                                         stride=2)

    def forward(self, inputs, down_outputs):
        # TODO: Upsampling required after deconv?
        outputs = self.up(inputs)
        # outputs = self.deconv(inputs)
        out = torch.cat([outputs, down_outputs], 1)
        return out

In [26]:
class UNetSmall(nn.Module):
    def __init__(self, num_channels=1, num_classes=2):
        super(UNetSmall, self).__init__()
        num_feat = [32, 64, 128, 256]
        self.input_size = input_size

        self.down1 = nn.Sequential(DoubleConv(num_channels, num_feat[0]))

        self.down2 = nn.Sequential(nn.MaxPool2d(kernel_size=2),
                                   Gru(num_feat[0]),
                                   Conv(num_feat[0], num_feat[1]))

        self.down3 = nn.Sequential(nn.MaxPool2d(kernel_size=2),
                                   Gru(num_feat[1]),
                                   Conv(num_feat[1], num_feat[2]))

        self.bottom = nn.Sequential(nn.MaxPool2d(kernel_size=2),
                                   Gru(num_feat[2]),
                                   Conv(num_feat[2], num_feat[3]))

        self.up1 = UpSample(num_feat[3], num_feat[2])
        self.upconv1 = nn.Sequential(Conv3x3Small(num_feat[3] + num_feat[2], num_feat[2]),
                                     nn.BatchNorm2d(num_feat[2]))

        self.up2 = UpSample(num_feat[2], num_feat[1])
        self.upconv2 = nn.Sequential(Conv3x3Small(num_feat[2] + num_feat[1], num_feat[1]),
                                     nn.BatchNorm2d(num_feat[1]))

        self.up3 = UpSample(num_feat[1], num_feat[0])
        self.upconv3 = nn.Sequential(Conv3x3Small(num_feat[1] + num_feat[0], num_feat[0]),
                                     nn.BatchNorm2d(num_feat[0]))

        self.final = nn.Sequential(nn.Conv2d(num_feat[0],
                                             1,
                                             kernel_size=1),
                                   nn.Sigmoid())

    def forward(self, inputs, return_features=False):
        inputs = inputs.reshape(-1, 1, self.input_size, self.input_size)
        print(inputs.data.size())
        down1_feat = self.down1(inputs)
        # print(down1_feat.size())
        down2_feat = self.down2(down1_feat)
        # print(down2_feat.size())
        down3_feat = self.down3(down2_feat)
        # print(down3_feat.size())
        bottom_feat = self.bottom(down3_feat)

        # print(bottom_feat.size())
        up1_feat = self.up1(bottom_feat, down3_feat)
        # print(up1_feat.size())
        up1_feat = self.upconv1(up1_feat)
        # print(up1_feat.size())
        up2_feat = self.up2(up1_feat, down2_feat)
        # print(up2_feat.size())
        up2_feat = self.upconv2(up2_feat)
        # print(up2_feat.size())
        up3_feat = self.up3(up2_feat, down1_feat)
        # print(up3_feat.size())
        up3_feat = self.upconv3(up3_feat)
        # print(up3_feat.size())

        if return_features:
            outputs = up3_feat
        else:
            outputs = self.final(up3_feat)

        return outputs

In [27]:
model = UNetSmall()

In [30]:
x = torch.rand(1,3,1,128,128)

In [31]:
model(x)

torch.Size([3, 1, 128, 128])


RuntimeError: shape '[1, 3, 32, 128, 128]' is invalid for input of size 393216