This is a neural network based on the modified U-Net architecture found in the DeepHarmony paper (as pictured below). In addition, it features batch normalization layers integrated within the network and a compound loss function made up of MS-SSIM and L1. 


Importing necessary modules

In [1]:
import os

import torch
from matplotlib import pyplot as plt
from torch import nn
from torch.utils.data import DataLoader


Definition of the network class and its convolution building blocks (the compound MS-SSIM + L1 loss function is not yet implemented in this cell)

In [8]:
class Unet(nn.Module):
    """Building blocks for a modified U-Net (DeepHarmony-style) in PyTorch.

    The class stores the data folders and exposes three helpers that each
    apply one convolution stage to a tensor. All tensors are expected in
    PyTorch's channels-first layout ``(batch, channels, height, width)``.

    NOTE(review): every helper constructs new layers on each call, so the
    weights are re-initialised each time and are not registered as
    sub-modules (they do not appear in ``self.parameters()``). For training,
    the layers should be created once in ``__init__``; that needs the full
    layer layout, which this class does not define yet.
    """

    def __init__(self, input_folder, output_folder):
        """Store the data locations and the nominal input size.

        Args:
            input_folder: Folder the input images are read from.
            output_folder: Folder results are written to.
        """
        super().__init__()
        self.input_folder = input_folder
        self.output_folder = output_folder
        self.input_size = (512, 512, 1)

    def convolution(self, input_image, in_channels, out_channels, kernel_size, stride):
        """Apply an unpadded convolution followed by batch normalization.

        Args:
            input_image: Tensor of shape ``(batch, in_channels, H, W)``.
            in_channels: Number of channels in ``input_image``.
            out_channels: Number of filters the convolution produces.
            kernel_size: Kernel size passed to ``nn.Conv2d``.
            stride: Stride passed to ``nn.Conv2d``.

        Returns:
            The batch-normalized convolution output with ``out_channels``
            channels.
        """
        conv_layer = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
        )
        # Batch norm must be sized to the channels the convolution produces,
        # not hard-coded to 1.
        norm_layer = nn.BatchNorm2d(num_features=out_channels, eps=0.0001)
        return norm_layer(conv_layer(input_image))

    def down_convolution(self, inp, n_filters):
        """Halve the spatial size with a strided 4x4 convolution.

        The stage is convolution -> ReLU -> batch normalization. With
        ``kernel_size=4, stride=2, padding=1`` an even-sized input of height
        ``H`` yields ``H / 2``, which matches "same" padding for even sizes.

        Args:
            inp: Tensor of shape ``(batch, channels, H, W)``.
            n_filters: Number of output channels.

        Returns:
            Tensor of shape ``(batch, n_filters, H / 2, W / 2)`` for even
            ``H`` and ``W``.
        """
        # The input channel count is read from the tensor because the
        # signature only carries the number of output filters.
        conv_layer = nn.Conv2d(inp.shape[1], n_filters, kernel_size=4, stride=2, padding=1)
        # He-normal weights with zero bias, as the original initializer asked.
        nn.init.kaiming_normal_(conv_layer.weight, nonlinearity="relu")
        nn.init.zeros_(conv_layer.bias)
        activated = torch.relu(conv_layer(inp))
        norm_layer = nn.BatchNorm2d(num_features=n_filters, eps=0.0001)
        return norm_layer(activated)

    def up_convolution(self, inp, n_filters, conv_features):
        """Double the spatial size and concatenate the skip connection.

        A stride of 0.5 is a fractionally-strided (transposed) convolution,
        implemented here as ``nn.ConvTranspose2d`` with ``stride=2``.

        Args:
            inp: Tensor of shape ``(batch, channels, H, W)`` to upsample.
            n_filters: Number of channels the transposed convolution produces.
            conv_features: Skip-connection tensor of shape
                ``(batch, skip_channels, 2 * H, 2 * W)``.

        Returns:
            ``conv_features`` followed by the upsampled tensor, joined along
            the channel dimension:
            ``(batch, skip_channels + n_filters, 2 * H, 2 * W)``.
        """
        deconv_layer = nn.ConvTranspose2d(inp.shape[1], n_filters, kernel_size=4, stride=2, padding=1)
        nn.init.kaiming_normal_(deconv_layer.weight, nonlinearity="relu")
        nn.init.zeros_(deconv_layer.bias)
        deconvolution = torch.relu(deconv_layer(inp))
        # Channels live on dim 1 in PyTorch's NCHW layout.
        return torch.cat([conv_features, deconvolution], dim=1)

Importing data, running preprocessing, and splitting into training, testing, and validation sets

In [9]:
# Split the preprocessed input files into training, testing and validation sets.
# NOTE(review): `Preprocessing` is not defined or imported in the visible
# cells; confirm it is in scope before this cell runs.
preprocessing = Preprocessing("./data/modified/", "./data/preprocessed/")

# Exclusive end index of each split in the sorted file listing:
# 179 training files, then 44 testing files, then 44 validation files.
TRAIN_END = 179
TEST_END = 223
VALIDATION_END = 267

# os.listdir returns entries in arbitrary order; sorting keeps the split
# identical between runs and machines.
input_files = sorted(os.listdir(preprocessing.input_folder))

# Fail with a clear message instead of a bare IndexError (or silently short
# splits) when the folder holds fewer files than the split expects.
if len(input_files) < VALIDATION_END:
    raise ValueError(
        f"Expected at least {VALIDATION_END} input files in "
        f"{preprocessing.input_folder!r}, found {len(input_files)}"
    )

training = input_files[:TRAIN_END]
testing = input_files[TRAIN_END:TEST_END]
validation = input_files[TEST_END:VALIDATION_END]


Training Model