In [1]:
import torch
import torch.nn as nn

In [16]:
class DenseBlock(nn.Module):
  """Five-layer densely connected convolution block.

  Every layer is a 3x3 Conv2d (stride 1, padding 1) followed by
  LeakyReLU(negative_slope=0.2). Layer ``k`` (1..5) consumes the channel-wise
  concatenation of the block input and the outputs of all previous layers, so
  its input width is ``in_channels + (k - 1) * out_channels``. Spatial size is
  preserved; the block returns the last layer's output (``out_channels``
  channels).

  Args:
    in_channels: number of channels of the input tensor.
    out_channels: number of channels produced by every layer (the growth rate)
      and by the block as a whole.
  """

  def __init__(self, in_channels = 64, out_channels = 64):
    super().__init__()

    self.in_channels = in_channels
    self.out_channels = out_channels
    self.kernel_size = 3
    self.stride = 1
    self.padding = 1

    # Attribute names block1..block5 (each an nn.Sequential(conv, act)) are kept
    # so that state_dict keys such as "block1.0.weight" stay unchanged.
    self.block1 = self._conv_block(in_channels)
    self.block2 = self._conv_block(in_channels + 1 * out_channels)
    self.block3 = self._conv_block(in_channels + 2 * out_channels)
    self.block4 = self._conv_block(in_channels + 3 * out_channels)
    self.block5 = self._conv_block(in_channels + 4 * out_channels)

  def _conv_block(self, n_in):
    """Build Conv2d(n_in -> out_channels) followed by LeakyReLU(0.2)."""
    return nn.Sequential(
        nn.Conv2d(in_channels=n_in, out_channels=self.out_channels, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding),
        nn.LeakyReLU(negative_slope=0.2)
    )

  def forward(self, x):
    # Dense connectivity: each layer's output is appended (along the channel
    # dim) to the running concatenation, which the next layer consumes.
    # The first step concatenates x with block1(x); concatenating x with
    # itself would discard block1's output entirely.
    features = x
    for block in (self.block1, self.block2, self.block3, self.block4):
      features = torch.cat((features, block(features)), dim=1)

    return self.block5(features)

In [17]:
# Shape smoke test: with in_channels == out_channels == 64 the block maps a
# (1, 64, 64, 64) input to an output of the same shape.
netDense = DenseBlock(in_channels=64, out_channels=64)
netDense(torch.randn(1, 64, 64, 64)).shape

torch.Size([1, 64, 64, 64])

In [35]:
class ResidualInResidual(nn.Module):
  """Residual-in-residual stack of three DenseBlocks (nn.Sequential container).

  Inner residual: every DenseBlock adds a scaled correction to its own input,
  ``h <- h + alpha * DenseBlock(h)``. Outer residual: the result of the whole
  chain is scaled and added back to the block input, ``x + alpha * h``.

  Args:
    in_channels: number of channels of the input tensor.
    out_channels: channels produced by each DenseBlock. Chaining the blocks and
      the residual additions require ``in_channels == out_channels``.
    alpha: residual scaling factor applied to every residual branch.
  """

  def __init__(self, in_channels = 64, out_channels = 64, alpha = 0.85):
    super().__init__()

    self.in_channels = in_channels
    self.out_channels = out_channels
    # Honor the constructor argument (it used to be silently overridden).
    self.alpha = alpha

    self.dense_block = nn.Sequential(*[DenseBlock(in_channels=self.in_channels, out_channels=self.out_channels) for _ in range(3)])

  def forward(self, x):
    out = x
    for block in self.dense_block:
      # Inner residual: scaled dense-block output added to that block's input.
      out = out + self.alpha * block(out)
    # Outer residual: scaled chain output added to the original input.
    return x + self.alpha * out


In [36]:
# Shape smoke test for the Sequential-based ResidualInResidual: a
# (1, 64, 64, 64) input should come back with the same shape.
netResidual = ResidualInResidual(in_channels=64, out_channels=64)
netResidual(torch.randn(1, 64, 64, 64)).shape

torch.Size([1, 64, 64, 64])

In [37]:
class ResidualInResidual(nn.Module):
  """Residual-in-residual stack of three DenseBlocks (explicit attributes).

  This redefinition shadows the earlier ``ResidualInResidual`` class in this
  notebook; it stores the dense blocks as ``dense_block1..3`` instead of an
  nn.Sequential.

  Inner residual: every DenseBlock adds a scaled correction to its own input,
  ``h <- h + alpha * DenseBlock(h)``. Outer residual: the result of the whole
  chain is scaled and added back to the block input, ``x + alpha * h``.

  Args:
    in_channels: number of channels of the input tensor.
    out_channels: channels produced by each DenseBlock. Chaining the blocks and
      the residual additions require ``in_channels == out_channels``.
    alpha: residual scaling factor applied to every residual branch.
  """

  def __init__(self, in_channels = 64, out_channels = 64, alpha = 0.85):
    super().__init__()

    self.in_channels = in_channels
    self.out_channels = out_channels
    # Honor the constructor argument (it used to be silently overridden).
    self.alpha = alpha

    self.dense_block1 = DenseBlock(in_channels=self.in_channels, out_channels=self.out_channels)
    self.dense_block2 = DenseBlock(in_channels=self.in_channels, out_channels=self.out_channels)
    self.dense_block3 = DenseBlock(in_channels=self.in_channels, out_channels=self.out_channels)

  def forward(self, x):
    # Each dense block is applied exactly once, in order; reusing
    # dense_block1 three times would leave dense_block2/3 untrained.
    out = x + self.alpha * self.dense_block1(x)
    out = out + self.alpha * self.dense_block2(out)
    out = out + self.alpha * self.dense_block3(out)

    # Outer residual: scaled chain output added to the original input.
    return x + self.alpha * out

In [39]:
# Re-instantiate using the redefined ResidualInResidual (dense_block1..3)
# and check that a (1, 64, 64, 64) input keeps its shape.
netResidual = ResidualInResidual(in_channels=64, out_channels=64)
netResidual(torch.randn(1, 64, 64, 64)).shape

torch.Size([1, 64, 64, 64])