In [40]:
import torch
import torch.nn as nn

In [52]:
class DenseBlock(nn.Module):
  """Residual dense block: five 3x3 convolutions with dense connectivity.

  Each of the first four convolutions (each followed by LeakyReLU(0.2)) sees
  the block input concatenated with the outputs of every earlier convolution,
  so its input width grows by ``out_channels`` per stage. The fifth
  convolution (no activation) maps the full concatenation back to
  ``in_channels``; that result is scaled by ``res_scale`` and added to the
  block input as a residual. Spatial size is preserved (3x3 kernels, stride 1,
  padding 1).

  Args:
    in_channels: channels of the block input, and therefore of its output.
    out_channels: channels produced by each of the four intermediate
      convolutions (the dense growth width).
    res_scale: multiplier applied to the dense branch before the residual add.
  """

  def __init__(self, in_channels = 64, out_channels = 64, res_scale = 0.2):
    super().__init__()

    self.in_channels = in_channels
    self.out_channels = out_channels
    self.res_scale = res_scale

    # Stage k (1-based) sees the block input plus the k-1 earlier stage outputs.
    # This equals k * in_channels only when in_channels == out_channels.
    self.block1 = self.block(in_channels = in_channels, out_channels = out_channels, use_leaky = True)
    self.block2 = self.block(in_channels = in_channels + 1 * out_channels, out_channels = out_channels, use_leaky = True)
    self.block3 = self.block(in_channels = in_channels + 2 * out_channels, out_channels = out_channels, use_leaky = True)
    self.block4 = self.block(in_channels = in_channels + 3 * out_channels, out_channels = out_channels, use_leaky = True)
    # The last stage projects back to in_channels so the residual add in
    # forward() is shape-compatible with the block input.
    self.block5 = self.block(in_channels = in_channels + 4 * out_channels, out_channels = in_channels, use_leaky = False)

  def block(self, in_channels = None, out_channels = None, use_leaky = True):
    """Build one stage: a biased 3x3 Conv2d (stride 1, padding 1), optionally
    followed by an in-place LeakyReLU with negative_slope 0.2.

    Returns an nn.Sequential whose index 0 is the conv and, when
    ``use_leaky`` is true, index 1 is the activation.
    """
    layers = [
      nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=True,
      )
    ]

    if use_leaky:
      # In-place is safe: it only overwrites the conv's freshly allocated output.
      layers.append(nn.LeakyReLU(inplace=True, negative_slope=0.2))

    return nn.Sequential(*layers)

  def forward(self, x):
    """Return ``x + res_scale * dense_branch(x)``; output shape equals ``x``'s."""
    features = x
    for stage in (self.block1, self.block2, self.block3, self.block4):
      # Newest features are concatenated in front of the running feature stack.
      features = torch.cat((stage(features), features), dim = 1)

    output = self.block5(features)  # back to in_channels

    return (output * self.res_scale) + x



In [55]:
# Trunk of five identical DenseBlocks (64 channels in, 64 per stage),
# chained one after another.
layers = [DenseBlock(in_channels=64, out_channels=64) for _ in range(5)]
model = nn.Sequential(*layers)

print(model)

Sequential(
  (0): DenseBlock(
    (block1): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (block2): Sequential(
      (0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (block3): Sequential(
      (0): Conv2d(192, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (block4): Sequential(
      (0): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (block5): Sequential(
      (0): Conv2d(320, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (1): DenseBlock(
    (block1): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (block2):

In [56]:
!pip install torchview

Collecting torchview
  Downloading torchview-0.2.6-py3-none-any.whl (25 kB)
Installing collected packages: torchview
Successfully installed torchview-0.2.6


In [57]:
from torchview import draw_graph

In [60]:
# Trace `model` on one 64-channel 64x64 dummy input and render the graph to
# ./model.jpeg; the cell's value is the path of the written file.
draw_graph(
    model=model,
    input_data=torch.randn(1, 64, 64, 64),
).visual_graph.render(
    filename="./model",
    format="jpeg",
)

'model.jpeg'