In [10]:
import os
import numpy as np

import scipy.ndimage.morphology as morph

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms

from PIL import Image

to_pil = transforms.ToPILImage()  # tensor -> PIL image, handy for eyeballing samples
os.listdir(os.path.join('..', 'r_unet', 'data'))  # show the data folder layout as this cell's output

['images_val', 'labels_val', 'test', 'labels', 'images']

In [11]:
DEVICE = "cuda:0"
# arguments
TIMESTEPS = 3          # consecutive frames per sample (sliding-window length)
BATCH_SIZE = 1
NUM_EPOCH = 50
INPUT_SIZE = 128       # frames are resized to INPUT_SIZE x INPUT_SIZE
INPUT_CHANNELS = 1
NUM_CLASSES = 2
LEARNING_RATE = 0.001

RECURRENT = True

# Resize with nearest-neighbour interpolation so mask pixel values are copied,
# never blended, then convert to a tensor (ToTensor scales 8-bit images to [0, 1]).
# InterpolationMode replaces the deprecated integer code (0 == nearest).
transform = transforms.Compose([
    transforms.Resize((INPUT_SIZE, INPUT_SIZE),
                      interpolation=transforms.InterpolationMode.NEAREST),
    transforms.ToTensor(),
])

to_tensor = transforms.ToTensor()

# device: fall back to the CPU when CUDA is not available
device = torch.device(DEVICE if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [7]:
# way to the data folders
FOLDER_DATA = "../r_unet/data/images"
FOLDER_MASK = "../r_unet/data/labels"
FOLDER_TEST = "../r_unet/data/test"
FOLDER_DATA_VAL = "../r_unet/data/images_val"
FOLDER_MASK_VAL = "../r_unet/data/labels_val"

FILE_NAMES = sorted(os.listdir('../r_unet/data/images'))
FILE_NAMES_VAL = sorted(os.listdir('../r_unet/data/images_val'))

In [126]:
def get_labels(object):
    """Split a mask tensor into two float indicator channels.

    Channel 0 is 1.0 where the mask equals zero (background), and channel 1 is
    1.0 everywhere else. The two channels are stacked on a new dim=1, so an
    input of shape (C, H, W) yields (C, 2, H, W).
    """
    background = object.eq(0).float()
    foreground = object.ne(0).float()  # exact complement of `background`, NaN included
    return torch.stack((background, foreground), dim=1)

In [127]:
class TrainMedData(Dataset):
    """Sliding-window dataset over the training frames.

    Item ``idx`` uses the ``TIMESTEPS`` consecutive names starting at
    ``FILE_NAMES[idx]``. Each name is read from ``FOLDER_DATA`` (image) and from
    ``FOLDER_MASK`` (mask). The item is ``(gif_data, gif_mask, gif_depth)``,
    each stacked along a new leading time axis:

    * ``gif_data``  - images passed through ``transform``.
    * ``gif_mask``  - ``get_labels`` of the transformed mask.
    * ``gif_depth`` - Euclidean distance transform of the mask scaled by 1/255,
      computed after a nearest-neighbour resize to the label's H x W.

    NOTE(review): for a single-channel mask, ``gif_depth`` here is
    (T, 1, 1, H, W), while ``ValMedData`` (which has no ``unsqueeze(0)``)
    yields (T, 1, H, W). The layout is left unchanged; confirm which one the
    loss expects.
    """

    def __init__(self):
        super().__init__()
        self.time = TIMESTEPS
        self.folder_data = FOLDER_DATA
        self.folder_mask = FOLDER_MASK
        self.file_names = FILE_NAMES

    def _load_frame(self, name):
        """Return ``transform`` applied to image ``name``; the file is closed afterwards."""
        with Image.open(self.folder_data + '/' + name) as img:
            return transform(img)

    def _load_mask(self, name):
        """Return ``(label, depth)`` for mask ``name``, reading the file only once."""
        with Image.open(self.folder_mask + '/' + name) as img:
            label = get_labels(transform(img)).unsqueeze(0)
            # Keep the depth map aligned with the label tensor (PIL sizes are (W, H)).
            small = img.resize((label.shape[-1], label.shape[-2]), resample=Image.NEAREST)
            depth = to_tensor(morph.distance_transform_edt(np.asarray(small) / 255)).unsqueeze(0)
        return label, depth

    def __getitem__(self, idx):
        # Index (rather than slice) so an out-of-range window raises IndexError.
        names = [self.file_names[idx + i] for i in range(self.time)]
        frames, labels, depths = [], [], []
        for name in names:
            frames.append(self._load_frame(name))
            label, depth = self._load_mask(name)
            labels.append(label)
            depths.append(depth)
        gif_data = torch.stack(frames)
        # Each label is (1, C, 2, H, W). After stacking, squeezing dim 2 drops C
        # when it is 1, giving (T, 1, 2, H, W) for single-channel masks.
        gif_mask = torch.stack(labels).squeeze(dim=2)
        gif_depth = torch.stack(depths)
        return gif_data, gif_mask, gif_depth

    def __len__(self):
        # Number of full windows; 0 (not a negative length) when there are too few files.
        return max(0, len(self.file_names) - self.time + 1)


class ValMedData(Dataset):
    """Sliding-window dataset over the validation frames.

    Item ``idx`` uses the ``TIMESTEPS`` consecutive names starting at
    ``FILE_NAMES_VAL[idx]``. Each name is read from ``FOLDER_DATA_VAL`` (image)
    and from ``FOLDER_MASK_VAL`` (mask). The item is
    ``(gif_data, gif_mask, gif_depth)``, each stacked along a new leading time
    axis:

    * ``gif_data``  - images passed through ``transform``.
    * ``gif_mask``  - ``get_labels`` of the transformed mask.
    * ``gif_depth`` - Euclidean distance transform of the mask scaled by 1/255,
      computed after a nearest-neighbour resize to the label's H x W.

    NOTE(review): unlike ``TrainMedData``, the depth maps are not
    ``unsqueeze(0)``-ed, so for a single-channel mask ``gif_depth`` is
    (T, 1, H, W) here rather than (T, 1, 1, H, W). The layout is left
    unchanged; confirm which one the loss expects.
    """

    def __init__(self):
        super().__init__()
        self.time = TIMESTEPS
        self.folder_data = FOLDER_DATA_VAL
        self.folder_mask = FOLDER_MASK_VAL
        self.file_names = FILE_NAMES_VAL

    def _load_frame(self, name):
        """Return ``transform`` applied to image ``name``; the file is closed afterwards."""
        with Image.open(self.folder_data + '/' + name) as img:
            return transform(img)

    def _load_mask(self, name):
        """Return ``(label, depth)`` for mask ``name``, reading the file only once."""
        with Image.open(self.folder_mask + '/' + name) as img:
            label = get_labels(transform(img))
            # Keep the depth map aligned with the label tensor (PIL sizes are (W, H)).
            small = img.resize((label.shape[-1], label.shape[-2]), resample=Image.NEAREST)
            depth = to_tensor(morph.distance_transform_edt(np.asarray(small) / 255))
        return label, depth

    def __getitem__(self, idx):
        # Index (rather than slice) so an out-of-range window raises IndexError.
        names = [self.file_names[idx + i] for i in range(self.time)]
        frames, labels, depths = [], [], []
        for name in names:
            frames.append(self._load_frame(name))
            label, depth = self._load_mask(name)
            labels.append(label)
            depths.append(depth)
        gif_data = torch.stack(frames)
        # Each label is (C, 2, H, W), so after stacking, dim 2 is the 2-class
        # axis and this squeeze leaves the tensor unchanged. It is kept as-is.
        gif_mask = torch.stack(labels).squeeze(dim=2)
        gif_depth = torch.stack(depths)
        return gif_data, gif_mask, gif_depth

    def __len__(self):
        # Number of full windows; 0 (not a negative length) when there are too few files.
        return max(0, len(self.file_names) - self.time + 1)


class TestMedData(Dataset):
    """Sliding-window dataset of unlabeled frames read from ``FOLDER_TEST``.

    File names are ``FILE_NAMES + FILE_NAMES_VAL``: training names followed by
    validation names, each opened from ``FOLDER_TEST``. Item ``idx`` stacks the
    ``TIMESTEPS`` transformed frames starting at ``idx`` along a new leading axis.

    NOTE(review): this assumes ``FOLDER_TEST`` contains every training and
    validation file name. Windows near the junction of the two lists mix the
    last training frames with the first validation frames; confirm that this
    is intended.
    """

    def __init__(self):
        super().__init__()
        self.time = TIMESTEPS
        self.folder_test = FOLDER_TEST
        self.file_names = FILE_NAMES + FILE_NAMES_VAL

    def _load_frame(self, name):
        """Return ``transform`` applied to test image ``name``; the file is closed afterwards."""
        with Image.open(self.folder_test + '/' + name) as img:
            return transform(img)

    def __getitem__(self, idx):
        # Index (rather than slice) so an out-of-range window raises IndexError.
        return torch.stack([self._load_frame(self.file_names[idx + i]) for i in range(self.time)])

    def __len__(self):
        # Number of full windows; 0 (not a negative length) when there are too few files.
        return max(0, len(self.file_names) - self.time + 1)

In [128]:
train_dataset = TrainMedData()
valid_dataset = ValMedData()
test_dataset = TestMedData()

# phase -> (dataset, batch size); the test loader always uses one window per batch.
_phase_specs = {
    'train': (train_dataset, BATCH_SIZE),
    'valid': (valid_dataset, BATCH_SIZE),
    'test': (test_dataset, 1),
}

# Every loader reads windows in index order (shuffle=False) with two worker processes.
data_loaders = {
    phase: DataLoader(dataset=ds, batch_size=bs, num_workers=2, shuffle=False)
    for phase, (ds, bs) in _phase_specs.items()
}
train_loader = data_loaders['train']
valid_loader = data_loaders['valid']
test_loader = data_loaders['test']

dataset_sizes = {phase: len(ds) for phase, (ds, _bs) in _phase_specs.items()}

In [133]:
class ConvRelu(nn.Module):
    """Convolution followed by ReLU; by default a 3x3 "same" convolution.

    Args:
        in_channels: number of input channels.
        out_chennels: number of output channels. The name keeps its original
            spelling so existing keyword callers still work.
        kernel_size: convolution kernel size (default 3).
        padding: zero padding per side. The default of 1 keeps H and W
            unchanged with a 3x3 kernel, so encoder and decoder feature maps
            stay the same size for skip concatenation.
    """

    def __init__(self, in_channels, out_chennels, kernel_size=3, padding=1):
        super(ConvRelu, self).__init__()
        # nn.Conv2d requires kernel_size. The layer stays inside a Sequential
        # named `convrelu` so parameter names remain convrelu.0.*.
        self.convrelu = nn.Sequential(
            nn.Conv2d(in_channels, out_chennels, kernel_size=kernel_size, padding=padding),
            nn.ReLU())

    def forward(self, x):
        return self.convrelu(x)


class MaxPool(nn.Module):
    """2x2 max pooling with stride 2, halving H and W (rounding down)."""

    def __init__(self):
        super().__init__()
        self.maxpool = nn.MaxPool2d(2)  # stride defaults to the kernel size

    def forward(self, x):
        pooled = self.maxpool(x)
        return pooled


class UpAndCat(nn.Module):
    """Upsample ``x_up`` 2x (nearest neighbour), then concatenate ``x_cat`` after it on dim 1."""

    def __init__(self):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x_up, x_cat):
        upsampled = self.up(x_up)
        return torch.cat((upsampled, x_cat), dim=1)

In [136]:
class ConvRnnCell(nn.Module):
    """One recurrent step: a 3x3 "same" conv over ``[x, hidden]`` concatenated on channels.

    The conv output is returned twice: once as the step output and once as the
    next hidden state. No nonlinearity is applied.
    """

    def __init__(self, channel):
        super().__init__()
        # Kept inside nn.Sequential so parameter names stay conv1.0.* (checkpoint-compatible).
        self.conv1 = nn.Sequential(
            nn.Conv2d(2 * channel, channel, kernel_size=3, padding=1)
        )

    def forward(self, x, hidden):
        new_hidden = self.conv1(torch.cat((x, hidden), dim=1))
        return new_hidden, new_hidden

In [135]:
# create class ConvRnn
class ConvRnn(nn.Module):
    """Run a ``ConvRnnCell`` over ``TIMESTEPS`` frames of every sample.

    Args:
        channels_size: channel count C of each frame and of the hidden state.
        ConvRnn_input_size: spatial size S (frames are S x S).

    ``forward(x)`` expects ``x`` to be reshapeable to (batch, TIMESTEPS, C, S, S),
    i.e. the frames of one sample are consecutive along the leading axis. It
    returns the per-step cell outputs as (batch * TIMESTEPS, C, S, S) in the same
    sample-major order. The hidden state starts at zeros for every sample.
    """

    def __init__(self, channels_size, ConvRnn_input_size): # arg for ConvRnn layer
        super(ConvRnn, self).__init__()
        # batch_size / hidden_size are kept for backward compatibility only;
        # forward() infers the batch size from its input.
        self.batch_size = BATCH_SIZE
        self.timesteps = TIMESTEPS
        self.channels_size = channels_size
        self.input_size = ConvRnn_input_size
        self.hidden_size = (self.batch_size, channels_size, ConvRnn_input_size, ConvRnn_input_size)

        self.ConvRnn_layer = ConvRnnCell(channels_size)

    def forward(self, x):
        frame_shape = (self.channels_size, self.input_size, self.input_size)
        # Infer the batch (-1) so partial batches and batch_size=1 loaders work.
        seq = x.reshape(-1, self.timesteps, *frame_shape)
        # Zero initial state created on the input's device/dtype, so it follows model.to(...).
        hidden = torch.zeros_like(seq[:, 0])
        outputs = []
        for t in range(self.timesteps):
            out, hidden = self.ConvRnn_layer(seq[:, t], hidden)
            outputs.append(out)
        # Stack on dim=1 -> (batch, timesteps, C, S, S) so flattening preserves
        # the input's sample-major order. Stacking on dim=0 would interleave
        # samples whenever batch > 1.
        return torch.stack(outputs, dim=1).reshape(-1, *frame_shape)

In [None]:
class UNetDesigner(nn.Module):
    
    def __init__(self, down1, down2, down3, bottom, up1, up2, up3,
                 input_size=INPUT_SIZE, input_channels=INPUT_CHANNELS, num_classes=NUM_CLASSES):
        super(UNetDesigner, self).__init__()
        self.input_size = input_size
        self.input_chennels = input_channels

    def forward(self, x):
        x = x.reshape(-1, self.input_chennels, self.input_size, self.input_size)