In [1]:
import torch
from functools import partial
from torch.utils.tensorboard import SummaryWriter

In this example we will be creating a [U-net](https://towardsdatascience.com/understanding-semantic-segmentation-with-unet-6be4f42d4b47#:~:text=The%20UNET%20was%20developed%20by,The%20architecture%20contains%20two%20paths.&text=Thus%20it%20is%20an%20end,accept%20image%20of%20any%20size.) model for predicting our wall shear stress. A U-net is an example of a [convolutional neural network](https://machinelearningmastery.com/convolutional-layers-for-deep-learning-neural-networks/).

First, we create the basic building block of our neural network: a simple block consisting of a [convolution](https://machinelearningmastery.com/convolutional-layers-for-deep-learning-neural-networks/), [batch normalization](https://towardsdatascience.com/batch-normalization-in-neural-networks-1ac91516821c) and a ReLU [activation function](https://medium.com/the-theory-of-everything/understanding-activation-functions-in-neural-networks-9491262884e0).

In [2]:
class ConvNormAct(torch.nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU building block.

    Parameters
    ----------
    in_channels : int
        Number of channels of the input tensor.
    out_channels : int
        Number of channels produced by the convolution.
    kernel_size : int or tuple of int, default 3
        Size of the convolution kernel.
    padding : "same", int or tuple of int, default "same"
        ``"same"`` computes the zero padding that keeps the spatial size
        unchanged (for stride 1); this requires an odd kernel size in every
        dimension. Any other value is passed to ``torch.nn.Conv2d`` unchanged.
    **kwargs
        Extra keyword arguments forwarded to ``torch.nn.Conv2d``
        (e.g. ``stride``, ``dilation``, ``bias``).

    Raises
    ------
    ValueError
        If ``padding == "same"`` and the kernel size is even in any dimension.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, padding="same", **kwargs):
        super().__init__()
        if padding == "same":
            padding = self._same_padding(kernel_size, kwargs.get("dilation", 1))
        self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, **kwargs)
        self.bnorm = torch.nn.BatchNorm2d(out_channels)
        self.activation = torch.nn.ReLU()

    @staticmethod
    def _same_padding(kernel_size, dilation=1):
        """Return the (height, width) zero padding that preserves spatial size at stride 1."""
        kernel_dims = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
        dilation_dims = (dilation, dilation) if isinstance(dilation, int) else tuple(dilation)
        # An even kernel cannot be padded symmetrically to keep the size unchanged.
        if any(k % 2 == 0 for k in kernel_dims):
            raise ValueError(f"padding='same' requires an odd kernel size, got {kernel_size}")
        # A dilated kernel spans dilation * (k - 1) + 1 pixels; pad half the extra span per side.
        return tuple(d * (k // 2) for k, d in zip(kernel_dims, dilation_dims))

    def forward(self, x):
        """Apply convolution, batch normalization and ReLU to ``x``."""
        return self.activation(self.bnorm(self.conv(x)))

Below we show a simple example of the layer we created, taking an input with 3 channels (features) and producing an output with 6 channels. Finally, we pass the output through a [max pooling](https://computersciencewiki.org/index.php/Max-pooling_/_Pooling#:~:text=Max%20pooling%20is%20a%20sample,in%20the%20sub%2Dregions%20binned.) layer, which halves the spatial size (256×256 → 128×128).

In [3]:
# A single random 3-channel 256x256 input (batch size 1).
x = torch.randn((1, 3, 256, 256))
layer = ConvNormAct(in_channels=3, out_channels=6, kernel_size=3)
pool = torch.nn.MaxPool2d(kernel_size=2)
# "same" padding keeps 256x256; the 2x2 max pool then halves it to 128x128.
output = pool(layer(x))
print(x.shape, output.shape)

torch.Size([1, 3, 256, 256]) torch.Size([1, 6, 128, 128])


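Now we need to create an upsampling layer for our data. We will use upsample convolutions (nearest-neighbour upsampling followed by a convolution) rather than transposed convolutions: they avoid the checkerboard artifacts that transposed convolutions tend to produce and generally converge faster.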

In [4]:
class UpsampleConv(torch.nn.Module):
    """Nearest-neighbour upsampling followed by a ``ConvNormAct`` block.

    Parameters
    ----------
    in_channels : int
        Number of channels of the input tensor.
    out_channels : int
        Number of channels produced by the convolution.
    kernel_size : int, default 3
        Kernel size of the convolution applied after upsampling.
    scale_factor : int or float, default 2
        Spatial upsampling factor (keyword-only).
    **kwargs
        Extra keyword arguments forwarded to ``ConvNormAct``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, *, scale_factor=2, **kwargs):
        super().__init__()
        # nn.Upsample(mode="nearest") is the non-deprecated equivalent of nn.UpsamplingNearest2d.
        self.upsample = torch.nn.Upsample(scale_factor=scale_factor, mode="nearest")
        self.conv = ConvNormAct(in_channels, out_channels, kernel_size, **kwargs)

    def forward(self, x):
        """Upsample ``x`` spatially, then apply the convolution block."""
        return self.conv(self.upsample(x))

In [5]:
# Map the pooled 6-channel 128x128 features back to 3 channels at 256x256.
upsample_layer = UpsampleConv(in_channels=6, out_channels=3)
print(upsample_layer(output).shape)

torch.Size([1, 3, 256, 256])


Now we have the tools to create our simple U-Net model. We will make a relatively shallow network (two pooling levels) and visualize it using [TensorBoard](https://www.tensorflow.org/tensorboard) through PyTorch's `torch.utils.tensorboard` integration.

In [6]:
class UNet(torch.nn.Module):
    """Shallow U-Net with two down-sampling levels and skip connections.

    Layout (C = ``base_channels``):
      encoder: [C, C] -> pool -> [2C, 2C] -> pool -> [4C, 4C]
      decoder: up(4C -> 2C) ++ skip(2C) -> [2C, 2C] -> up(2C -> C) ++ skip(C) -> [C, C]
      output:  1x1 conv C -> ``out_channels``

    Parameters
    ----------
    in_channels : int
        Number of channels of the input tensor.
    out_channels : int
        Number of channels of the predicted output.
    base_channels : int, default 64
        Width of the first level; doubled at each deeper level.
    kernel_size : int, default 3
        Kernel size of every ConvNormAct / UpsampleConv convolution. It must be a
        size ConvNormAct accepts with its default ``"same"`` padding. The final
        projection is always a 1x1 convolution.

    Notes
    -----
    The two 2x2 max pools halve the spatial size twice and the decoder doubles
    it twice, so the input height and width must be divisible by 4; otherwise
    the skip-connection ``torch.cat`` calls fail on mismatched shapes.
    """

    def __init__(self, in_channels, out_channels, base_channels=64, kernel_size=3):
        super().__init__()
        # Every block uses the requested kernel size.
        ConvWrapped = partial(ConvNormAct, kernel_size=kernel_size)
        # encoding layers
        self.conv1a = ConvWrapped(in_channels, base_channels)
        self.conv1b = ConvWrapped(base_channels, base_channels)
        self.pool_1 = torch.nn.MaxPool2d(2)
        self.conv2a = ConvWrapped(base_channels, 2*base_channels)
        self.conv2b = ConvWrapped(2*base_channels, 2*base_channels)
        self.pool_2 = torch.nn.MaxPool2d(2)
        self.conv3a = ConvWrapped(2*base_channels, 4*base_channels)
        self.conv3b = ConvWrapped(4*base_channels, 4*base_channels)
        # decoding layers
        self.upsample_1 = UpsampleConv(4*base_channels, 2*base_channels, kernel_size)
        # input: upsampled features (2C) concatenated with the level-2 skip (2C)
        self.conv4a = ConvWrapped(4*base_channels, 2*base_channels)
        self.conv4b = ConvWrapped(2*base_channels, 2*base_channels)
        self.upsample_2 = UpsampleConv(2*base_channels, base_channels, kernel_size)
        # input: upsampled features (C) concatenated with the level-1 skip (C)
        self.conv5a = ConvWrapped(2*base_channels, base_channels)
        self.conv5b = ConvWrapped(base_channels, base_channels)
        self.output_conv = torch.nn.Conv2d(base_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Run the encoder/decoder and return the 1x1-projected output map."""
        # level 1 (full resolution); keep the features for the skip connection
        x = self.conv1a(x)
        x = self.conv1b(x)
        c1 = x
        # level 2 (1/2 resolution)
        x = self.pool_1(x)
        x = self.conv2a(x)
        x = self.conv2b(x)
        c2 = x
        # bottleneck (1/4 resolution)
        x = self.pool_2(x)
        x = self.conv3a(x)
        x = self.conv3b(x)
        # back up to 1/2 resolution, fuse with the level-2 skip along channels
        x = self.upsample_1(x)
        x = torch.cat([x, c2], dim=1)
        x = self.conv4a(x)
        x = self.conv4b(x)
        # back up to full resolution, fuse with the level-1 skip along channels
        x = self.upsample_2(x)
        x = torch.cat([x, c1], dim=1)
        x = self.conv5a(x)
        x = self.conv5b(x)
        return self.output_conv(x)



Now we visualize the created network with TensorBoard.

In [7]:
# Trace the model once on a dummy input and write its graph to runs/view_model.
# Using SummaryWriter as a context manager guarantees the event file is flushed
# and closed even if tracing raises.
x = torch.randn(1, 3,  256, 256)
model = UNet(3, 1)
with SummaryWriter('runs/view_model') as writer:
    writer.add_graph(model, x)

In [None]:
# Load the TensorBoard notebook extension
%load_ext tensorboard
# Launch TensorBoard inline on the "runs" log directory. If the inline view does not work,
# run `tensorboard --logdir=runs` in a terminal from the directory that contains `runs/`.
%tensorboard --logdir=runs