In [1]:
import os
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms

from PIL import Image

os.listdir('../test_r_unet/data/')

['labels', 'test', 'images']

In [5]:
# arguments
timesteps = 3
batch_size = 4
num_epochs = 10
input_size = 128

gru_nan = False

transform = transforms.Compose([
                              transforms.Resize((input_size, input_size)),
                              transforms.ToTensor(),
                              ])

# decive
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [6]:
# way to the data folders
folder_data = "../test_r_unet/data/images"
folder_mask = "../test_r_unet/data/labels"
folder_test = "../test_r_unet/data/test"

file_names = os.listdir('../test_r_unet/data/images')

In [7]:
class MedData(Dataset):
    def __init__(self):
        super().__init__()
        self.time = timesteps
        self.folder_data = folder_data
        self.folder_mask = folder_mask
        self.file_names = file_names
        
    def __getitem__(self, idx):
        gif_list = []
        for i in range(self.time):
            gif_list.append(transform(Image.open(self.folder_data + '/' + file_names[idx+i])).unsqueeze(0))
        gif_data = torch.stack(gif_list)
        gif_list.clear()
        for i in range(self.time):
            gif_list.append(transform(Image.open(self.folder_mask + '/' + file_names[idx+i])).unsqueeze(0))
        gif_mask = torch.stack(gif_list)
        gif_list.clear()
        return gif_data, gif_mask
    
    def __len__(self):
        return len(self.file_names) - self.time + 1

In [8]:
dataset = MedData()
train_loader = DataLoader(dataset=dataset,
                          batch_size=batch_size,
                          num_workers=2,
                          shuffle=True)

In [9]:
class GruCell(nn.Module):
    
    def __init__(self, channel):
        super(GruCell, self).__init__()
        self.conv_relu = nn.Sequential(nn.Conv2d(in_channels=channel, out_channels=channel, kernel_size=3, stride=1, padding=1),
                                       nn.ELU(),
                                       nn.Dropout(p=0.2))
        
        self.conv_relu_2x = nn.Sequential(nn.Conv2d(in_channels=channel+channel, out_channels=channel, kernel_size=3, stride=1, padding=1),
                                          nn.ELU(),
                                          nn.Dropout(p=0.2))
        
        
        self.sig = nn.Sigmoid()

     
    def forward(self, x, hidden):
        input = torch.cat([x, hidden],dim=1)

        update_gate = self.conv_relu_2x(input)
        update_gate = self.sig((update_gate)) ### output after update gate
        
        reset_gate = self.conv_relu_2x(input)
        reset_gate = self.sig((reset_gate)) ### output after reset gate
        
        memory_gate_for_input = self.conv_relu(x)
        memory_gate_for_hidden = self.conv_relu(hidden)

        memory_content = self.sig((memory_gate_for_input + (reset_gate * memory_gate_for_hidden))) ### output for reset gate(affects how the reset gate do work)
        
        hidden = (update_gate * hidden) + ((1 - update_gate) * memory_content) # torch.ones(input_size, hidden_size)

        return hidden, hidden

In [38]:
# create class Gru
class Gru(nn.Module):

    def __init__(self, channels_size, gru_input_size): # arg for gru layer
        super(Gru, self).__init__()
        self.batch_size = batch_size
        self.timesteps = timesteps
        self.channels_size = channels_size
        self.input_size = gru_input_size
        self.hidden_size = (self.batch_size, channels_size, gru_input_size, gru_input_size)
        
        self.gru_layer0 = GruCell(channels_size)
        self.init_hidden = torch.zeros(self.hidden_size).to(device)
        self.gru_nan = gru_nan


    def forward(self, x):
        x_cells = None
        x_list = []
        if self.gru_nan == False:
            try:
                x = x.reshape(batch_size, timesteps, self.channels_size, self.input_size, self.input_size)
                x = x.permute(1, 0, 2, 3, 4)
                for i in range(timesteps):
                    if x_cells is None:
                        print(x[i].shape, self.init_hidden.shape)
                        x_cells, hidden = self.gru_layer0(x[i], self.init_hidden)
                        x_list.append(x_cells)
                    else:
                        x_i, hidden = self.gru_layer0(x[i], hidden)
                        x_list.append(x_i)
                x_cells = torch.stack(x_list)

             ##### FOR LAST BATCH
            except RuntimeError:
                x = x.reshape(1, timesteps, self.channels_size, self.input_size, self.input_size) #last batch is (15), but batch_size = 16, #arg.timesteps = 2 
                x = x.permute(1, 0, 2, 3, 4)
                hidden_zero = torch.zeros_like(x)
                for i in range(timesteps):
                    if x_cells is None:
                        x_cells, hidden = self.gru_layer0(x[i], hidden_zero[0])
                        x_list.append(x_cells)
                    else:
                        x_i, hidden = self.gru_layer0(x[i], hidden)
                        x_list.append(x_i)
                x_cells = torch.stack(x_list)
             #####
        elif self.gru_nan == True:
            try:
                x = x.reshape(batch_size, timesteps, self.channels_size, self.input_size, self.input_size)
                x = x.permute(1, 0, 2, 3, 4)
                for i in range(timesteps):
                    if x_cells is None:
                        x_cells, hidden = self.gru_layer0(x[i], self.init_hidden)
                        x_list.append(x_cells)
                    else:
                        x_i, hidden = self.gru_layer0(x[i], self.init_hidden)
                        x_list.append(x_i)
                x_cells = torch.stack((x_cells, x_i))
            ##### FOR LAST BATCH
            except RuntimeError:
                x = x.reshape(1, timesteps, self.channels_size, self.input_size, self.input_size) #last batch is (15), but batch_size = 16, #arg.timesteps = 2 
                x = x.permute(1, 0, 2, 3, 4)
                hidden_zero = torch.zeros_like(x)
                for i in range(timesteps):
                    if x_cells is None:
                        x_cells, hidden = self.gru_layer0(x[i], hidden_zero[0])
                        x_list.append(x_cells)
                    else:
                        x_i, hidden = self.gru_layer0(x[i], hidden)
                        x_list.append(x_i)
                x_cells = torch.stack((x_cells, x_i))
        else:
            print('gru_nan can be only True or False')
            quit()
        x_cells = x_cells.reshape(-1, self.channels_size, self.input_size, self.input_size)

        return x_cells  

In [39]:
class Conv3x3Small(nn.Module):
    def __init__(self, in_feat, out_feat):
        super(Conv3x3Small, self).__init__()

        self.conv1 = nn.Sequential(nn.Conv2d(in_feat, out_feat,
                                             kernel_size=3,
                                             stride=1,
                                             padding=1),
                                   nn.ELU(),
                                   nn.Dropout(p=0.2))

        self.conv2 = nn.Sequential(nn.Conv2d(out_feat, out_feat,
                                             kernel_size=3,
                                             stride=1,
                                             padding=1),
                                   nn.ELU())

    def forward(self, inputs):
        outputs = self.conv1(inputs)
        outputs = self.conv2(outputs)
        return outputs


class DoubleConv(nn.Module):

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ELU(),
            nn.Dropout(p=0.2),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ELU(),
            nn.Dropout(p=0.2)
        )

    def forward(self, x):
        return self.double_conv(x)

    
class Conv(nn.Module):

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ELU(),
            nn.Dropout(p=0.2)
        )

    def forward(self, x):
        return self.conv(x)


class UpConcat(nn.Module):
    def __init__(self, in_feat, out_feat):
        super(UpConcat, self).__init__()

        self.up = nn.UpsamplingBilinear2d(scale_factor=2)

        # self.deconv = nn.ConvTranspose2d(in_feat, out_feat,
        #                                  kernel_size=3,
        #                                  stride=1,
        #                                  dilation=1)

        self.deconv = nn.ConvTranspose2d(in_feat,
                                         out_feat,
                                         kernel_size=2,
                                         stride=2)

    def forward(self, inputs, down_outputs):
        # TODO: Upsampling required after deconv?
        # outputs = self.up(inputs)
        outputs = self.deconv(inputs)
        out = torch.cat([down_outputs, outputs], 1)
        return out


class UpSample(nn.Module):
    def __init__(self, in_feat, out_feat):
        super(UpSample, self).__init__()

        self.up = nn.Upsample(scale_factor=2, mode='nearest')

        self.deconv = nn.ConvTranspose2d(in_feat,
                                         out_feat,
                                         kernel_size=2,
                                         stride=2)

    def forward(self, inputs, down_outputs):
        # TODO: Upsampling required after deconv?
        outputs = self.up(inputs)
        # outputs = self.deconv(inputs)
        out = torch.cat([outputs, down_outputs], 1)
        return out

In [40]:
class UNetSmall(nn.Module):
    def __init__(self, num_channels=1, num_classes=2):
        super(UNetSmall, self).__init__()
        num_feat = [32, 64, 128, 256]
        self.input_size = input_size

        self.down1 = nn.Sequential(DoubleConv(num_channels, num_feat[0]))

        self.down2 = nn.Sequential(nn.MaxPool2d(kernel_size=2),
                                   Gru(num_feat[0], gru_input_size=64),
                                   Conv(num_feat[0], num_feat[1]))

        self.down3 = nn.Sequential(nn.MaxPool2d(kernel_size=2),
                                   Gru(num_feat[1], gru_input_size=32),
                                   Conv(num_feat[1], num_feat[2]))

        self.bottom = nn.Sequential(nn.MaxPool2d(kernel_size=2),
                                   Gru(num_feat[2], gru_input_size=16),
                                   Conv(num_feat[2], num_feat[3]))

        self.up1 = UpSample(num_feat[3], num_feat[2])
        self.upconv1 = nn.Sequential(Conv3x3Small(num_feat[3] + num_feat[2], num_feat[2]),
                                     nn.BatchNorm2d(num_feat[2]))

        self.up2 = UpSample(num_feat[2], num_feat[1])
        self.upconv2 = nn.Sequential(Conv3x3Small(num_feat[2] + num_feat[1], num_feat[1]),
                                     nn.BatchNorm2d(num_feat[1]))

        self.up3 = UpSample(num_feat[1], num_feat[0])
        self.upconv3 = nn.Sequential(Conv3x3Small(num_feat[1] + num_feat[0], num_feat[0]),
                                     nn.BatchNorm2d(num_feat[0]))

        self.final = nn.Sequential(nn.Conv2d(num_feat[0],
                                             1,
                                             kernel_size=1),
                                   nn.Sigmoid())

    def forward(self, inputs, return_features=False):
        inputs = inputs.reshape(-1, 1, self.input_size, self.input_size)
        # print(inputs.data.size())
        down1_feat = self.down1(inputs)
        # print(down1_feat.size())
        down2_feat = self.down2(down1_feat)
        # print(down2_feat.size())
        down3_feat = self.down3(down2_feat)
        # print(down3_feat.size())
        bottom_feat = self.bottom(down3_feat)

        # print(bottom_feat.size())
        up1_feat = self.up1(bottom_feat, down3_feat)
        # print(up1_feat.size())
        up1_feat = self.upconv1(up1_feat)
        # print(up1_feat.size())
        up2_feat = self.up2(up1_feat, down2_feat)
        # print(up2_feat.size())
        up2_feat = self.upconv2(up2_feat)
        # print(up2_feat.size())
        up3_feat = self.up3(up2_feat, down1_feat)
        # print(up3_feat.size())
        up3_feat = self.upconv3(up3_feat)
        # print(up3_feat.size())

        if return_features:
            outputs = up3_feat
        else:
            outputs = self.final(up3_feat)

        return outputs

In [41]:
model = UNetSmall()
model = model.to(device)

In [42]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)

def dice_loss(pred, target, smooth = 1.):
    pred = pred.contiguous()
    target = target.contiguous()    
    intersection = (pred * target).sum(dim=2).sum(dim=2)
    
    loss = (1 - ((2. * intersection + smooth) / (pred.sum(dim=2).sum(dim=2) + target.sum(dim=2).sum(dim=2) + smooth)))
    
    return loss.mean()

def calc_loss(prediction, target, bce_weight=0.3):
    try:
        prediction = prediction.reshape(timesteps*batch_size, 51, 128, 128)
        target = target.reshape(timesteps*batch_size, 51, 128, 128)
    except RuntimeError:
        prediction = prediction.reshape(timesteps*15, 51, 128, 128) # last_batch = 15
        target = target.reshape(timesteps*15, 51, 128, 128)
    bce = F.binary_cross_entropy_with_logits(prediction, target)
    prediction = F.sigmoid(prediction)
    dice = lossy(prediction, target)

    loss = bce * bce_weight + dice * (1 - bce_weight)

    return loss

def lossy(x, y):
    return (((x - y)**2).sum(dim=1)).sum()/(256**2)

criterion = torch.nn.BCEWithLogitsLoss()

In [43]:
def train():
    for epoch in range(num_epochs):
        for i in range(5):
            print('*'*20)
        print('epoch: ', epoch)
        iou_metric = 0
        for i, data in enumerate(train_loader):
            input, label = data
            input = input.to(device)
            label = label.to(device)
            output = model(input)
            #loss = calc_loss(output, label)
            try:
                loss = lossy(output.reshape(timesteps*batch_size, 1, 128, 128), 
                                  label.reshape(timesteps*batch_size, 1, 128, 128))
            except RuntimeError:
                loss = lossy(output.reshape(timesteps*1, 1, 128, 128), 
                                  label.reshape(timesteps*1, 1, 128, 128))      
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            print('iter: ', i, 'avg iou_metric: ', loss.item())
        print('-'*20)
        for i in range(5):
            print('*'*20)

In [None]:
train()

********************
********************
********************
********************
********************
epoch:  0
torch.Size([4, 32, 64, 64]) torch.Size([4, 32, 64, 64])
torch.Size([4, 64, 32, 32]) torch.Size([4, 64, 32, 32])
torch.Size([4, 128, 16, 16]) torch.Size([4, 128, 16, 16])
iter:  0 avg iou_metric:  0.5659788846969604
torch.Size([4, 32, 64, 64]) torch.Size([4, 32, 64, 64])
torch.Size([4, 64, 32, 32]) torch.Size([4, 64, 32, 32])
torch.Size([4, 128, 16, 16]) torch.Size([4, 128, 16, 16])
iter:  1 avg iou_metric:  0.5509542226791382
torch.Size([4, 32, 64, 64]) torch.Size([4, 32, 64, 64])
torch.Size([4, 64, 32, 32]) torch.Size([4, 64, 32, 32])
torch.Size([4, 128, 16, 16]) torch.Size([4, 128, 16, 16])
iter:  2 avg iou_metric:  0.5391201972961426
torch.Size([4, 32, 64, 64]) torch.Size([4, 32, 64, 64])
torch.Size([4, 64, 32, 32]) torch.Size([4, 64, 32, 32])
torch.Size([4, 128, 16, 16]) torch.Size([4, 128, 16, 16])
iter:  3 avg iou_metric:  0.5311733484268188
torch.Size([4, 32, 64, 64