In [1]:
import os
import datetime
import numpy as np
import pandas as pd
import cv2
import torch
import torch.nn as nn
from torchmetrics import Dice
from torchmetrics import MeanSquaredError
from torchvision.utils import save_image
import torch.nn.functional as functional
import matplotlib.pyplot as plt
from IPython import display

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device('cpu')
split = pd.read_csv("split.csv")
learning_rate = 0.01
epoch_num = 51

data_train = []
sem_train = []
data_dev = []
sem_dev = []
data_test = []
sem_test = []

for _, row in split.iterrows():
    img = cv2.resize(cv2.imread(row['img_path']), (2**8, 2**8), interpolation = cv2.INTER_AREA)
    sem = cv2.resize(np.delete(cv2.imread(row['sem_path']), (1, 2), 2), (2**8, 2**8), interpolation = cv2.INTER_AREA)
    def bool(x):
        return (x > 0)
    sem = np.vectorize(bool)(sem)
    if row['split'] == 'train':
        data_train.append(img)
        sem_train.append(sem)
    if row['split'] == 'dev':
        data_dev.append(img)
        sem_dev.append(sem)
    if row['split'] == 'test':
        data_test.append(img)
        sem_test.append(sem)

data_train = torch.FloatTensor(np.array(data_train)).to(device)
sem_train = torch.FloatTensor(np.array(sem_train)).to(device)
data_dev = torch.FloatTensor(np.array(data_dev)).to(device)
sem_dev = torch.FloatTensor(np.array(sem_dev)).to(device)
data_test = torch.FloatTensor(np.array(data_test)).to(device)
sem_test = torch.FloatTensor(np.array(sem_test)).to(device)

In [3]:
class Conv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Conv, self).__init__()
        self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=2)
        self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size=2)
        self.batchnorm = nn.BatchNorm2d(out_channels)
        self.activation = nn.ReLU()

    def forward(self, x):
        x = self.activation(self.batchnorm(self.conv_1(x)))
        x = self.activation(self.batchnorm(self.conv_2(x)))
        return x

class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.conv_1 = Conv(3, 2**6)
        self.conv_2 = Conv(2**6, 2**7)
        self.conv_3 = Conv(2**7, 2**8)
        self.conv_4 = Conv(2**8, 2**9)
        self.conv_5 = Conv(2**9, 2**10)
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        x1 = self.conv_1(x)
        x = self.pool(x1)
        x2 = self.conv_2(x)
        x = self.pool(x2)
        x3 = self.conv_3(x)
        x = self.pool(x3)
        x4 = self.conv_4(x)
        x = self.pool(x4)
        return (self.conv_5(x), x1, x2, x3, x4)

class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.out_shape = (2**8, 2**8)
        self.back_1 = nn.ConvTranspose2d(2**10, 2**9, kernel_size=2, stride=2)
        self.conv_1 = Conv(2**10, 2**9)
        self.back_2 = nn.ConvTranspose2d(2**9, 2**8, kernel_size=2, stride=2)
        self.conv_2 = Conv(2**9, 2**8)
        self.back_3 = nn.ConvTranspose2d(2**8, 2**7, kernel_size=2, stride=2)
        self.conv_3 = Conv(2**8, 2**7)
        self.back_4 = nn.ConvTranspose2d(2**7, 2**6, kernel_size=2, stride=2)
        self.conv_4 = Conv(2**7, 2**6)
        self.conv_5 = nn.Conv2d(2**6, 1, kernel_size=1)

    def pad(self, tensor, shape):
        dW = shape[-1] - tensor.shape[-1]
        dH = shape[-2] - tensor.shape[-2]
        val_w = dW // 2
        val_h = dH // 2
        return functional.pad(tensor, [val_w, dW - val_w, val_h, dH - val_h])

    def forward(self, x, x1, x2, x3, x4):
        x = torch.cat([x4, self.pad(self.back_1(x), x4.shape)], dim=1)
        x = torch.cat([x3, self.pad(self.back_2(self.conv_1(x)), x3.shape)], dim=1)
        x = torch.cat([x2, self.pad(self.back_3(self.conv_2(x)), x2.shape)], dim=1)
        x = torch.cat([x1, self.pad(self.back_4(self.conv_3(x)), x1.shape)], dim=1)
        return torch.sigmoid(self.pad(self.conv_5(self.conv_4(x)), self.out_shape))

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        x = torch.permute(x, (0, 3, 1, 2))
        x, x1, x2, x3, x4 = self.encoder(x)
        x = self.decoder(x, x1, x2, x3, x4)
        return x

In [4]:
model = Model().to(device)
opt = torch.optim.Adam(model.parameters(),lr=learning_rate)
lossfunction = nn.BCELoss()


In [7]:
for epoch_num in range(epoch_num):
    idx = np.random.randint(len(data_train), size=4)
    data_batch = data_train[idx]
    y_batch = sem_train[idx]
    
    predictions = torch.squeeze(model(data_batch), dim=1)
    
    loss = lossfunction(predictions, y_batch)
    loss.backward()
    opt.step()
    opt.zero_grad()
    
    loss.data.numpy()
    if epoch_num % 10 == 0:
        print('epoch =', epoch_num, 'loss = ', loss.data.numpy())

epoch = 0 loss =  0.4110845
epoch = 10 loss =  0.07000518
epoch = 20 loss =  0.084308356
epoch = 30 loss =  0.048253164
epoch = 40 loss =  0.367396


In [10]:
model.eval()
test = model(data_test)
test = test.detach().apply_(lambda x: (0 if x <= 0.5 else 1))
test = torch.squeeze(test, dim=1)

# print("(dev) loss value = ", lossfunction(torch.squeeze(test, dim=1), sem_test).data)
print("DiffFgMSE value = ", MeanSquaredError()(test, sem_test))
print("DiffFgDICE value = ", Dice()(test.type(torch.LongTensor), sem_test.type(torch.LongTensor)))

for (index, image) in enumerate(test):
    file_path = 'results/result_{num}.png'.format(num=index)
    if os.path.exists(file_path):
        os.remove(file_path)
    save_image(image * 255, file_path)
    if (index % 3 == 2):
        print('image', file_path, 'successfuly saved ', datetime.datetime.now())

DiffFgMSE value =  tensor(0.0249)
DiffFgDICE value =  tensor(0.9751)
image results/result_2.png successfuly saved  2023-07-30 16:35:45.134532
image results/result_5.png successfuly saved  2023-07-30 16:35:45.153313
image results/result_8.png successfuly saved  2023-07-30 16:35:45.171563
image results/result_11.png successfuly saved  2023-07-30 16:35:45.187328
image results/result_14.png successfuly saved  2023-07-30 16:35:45.202146
image results/result_17.png successfuly saved  2023-07-30 16:35:45.220131
image results/result_20.png successfuly saved  2023-07-30 16:35:45.236726
image results/result_23.png successfuly saved  2023-07-30 16:35:45.253441
image results/result_26.png successfuly saved  2023-07-30 16:35:45.271508
image results/result_29.png successfuly saved  2023-07-30 16:35:45.290904
image results/result_32.png successfuly saved  2023-07-30 16:35:45.309110
image results/result_35.png successfuly saved  2023-07-30 16:35:45.325247
image results/result_38.png successfuly saved 