RRDB (Residual-in-Residual Dense Block) combines three ideas: residual connections, dense connections, and residual-in-residual nesting (local + global skip connections). Together they make it possible to build a deep network without losing accuracy: dense connections redistribute the features more effectively,
and the residual-in-residual structure keeps training stable even with many layers. That is why such a network can be used for stable deep feature processing without exploding gradients. (It can be used in image-to-image networks, especially GANs.)

In [1]:
import torch 
from torch import nn 
import torch.nn.functional as F 

In [20]:
class ResDenseBlock(nn.Module):
    """Five-convolution dense block with a scaled residual connection.

    All convolutions are 3x3 with stride 1 and padding 1, so the spatial
    size is preserved. Each of the first four convolutions receives the
    block input concatenated (along the channel axis) with the outputs of
    every earlier convolution, and is followed by a leaky ReLU. The fifth
    convolution maps the full concatenation back to ``nf`` channels with no
    activation; its output is scaled and added to the block input.

    Args:
        nf: Number of channels of the block input and output.
        gc: Growth channels, i.e. channels produced by each of the first
            four convolutions.
        res_scale: Factor applied to the residual branch before it is added
            to the input.
        negative_slope: Negative slope of the leaky ReLU activations.
            Pass 0.01 to reproduce the previous implementation, which
            relied on the default slope of ``F.leaky_relu``.
    """

    def __init__(self, nf=32, gc=32, res_scale=0.2, negative_slope=0.2):
        super().__init__()
        self.res_scale = res_scale
        self.negative_slope = negative_slope

        # conv<k> consumes the input (nf channels) plus the k - 1 feature
        # maps (gc channels each) produced before it.
        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1)
        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1)
        self.conv3 = nn.Conv2d(nf + gc * 2, gc, 3, 1, 1)
        self.conv4 = nn.Conv2d(nf + gc * 3, gc, 3, 1, 1)
        self.conv5 = nn.Conv2d(nf + gc * 4, nf, 3, 1, 1)

    def forward(self, x):
        """Return ``x + res_scale * conv5(dense features)``.

        Args:
            x: Tensor with ``nf`` channels on dimension 1.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # `stacked` grows by gc channels per step: [x], [x, x1], [x, x1, x2], ...
        stacked = x
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            grown = F.leaky_relu(conv(stacked), self.negative_slope)
            stacked = torch.cat([stacked, grown], 1)

        # No activation on the last convolution; it only projects back to nf.
        residual = self.conv5(stacked)
        return x + residual * self.res_scale

In [21]:
# Shape check: push a random (1, 32, 64, 64) batch through a default block.
block = ResDenseBlock()
tensor = torch.randn(1, 32, 64, 64)

out = block(tensor)
# Last expression of the cell, so the notebook displays the output shape.
out.shape

torch.Size([1, 32, 64, 64])

The final RRDB (named `RRND` in the code below) is formed by chaining several `ResDenseBlock`s:

In [22]:
class RRND(nn.Module):
    """Residual-in-residual block: three chained ResDenseBlocks plus a skip.

    The input is passed through ``block1``, ``block2`` and ``block3`` in
    sequence; the result is multiplied by 0.2 and added to the original
    input, so the output has the same shape as the input.

    Args:
        nf: Channel count forwarded to every ResDenseBlock.
        gc: Growth-channel count forwarded to every ResDenseBlock.
    """

    def __init__(self, nf=64, gc=32):
        super().__init__()

        self.block1 = ResDenseBlock(nf, gc)
        self.block2 = ResDenseBlock(nf, gc)
        self.block3 = ResDenseBlock(nf, gc)

    def forward(self, x):
        """Apply the three dense blocks and add the scaled result to ``x``."""
        features = x
        for dense_block in (self.block1, self.block2, self.block3):
            features = dense_block(features)

        # Outer skip connection around the whole stack of dense blocks.
        scaled = features * 0.2
        return x + scaled

In [24]:
# Shape check: push a random (1, 64, 64, 64) batch through a default RRND.
block = RRND()
tensor = torch.randn(1, 64, 64, 64)

out = block(tensor)
# Last expression of the cell, so the notebook displays the output shape.
out.shape

torch.Size([1, 64, 64, 64])

RRDB is only a building block, not a standalone network. A complete network that uses it could look like this:

In [27]:
class Net(nn.Module):
    """Example image-to-image network built around three RRDB (`RRND`) blocks.

    Pipeline: two stride-2 convolutions downsample the input by 4x, three
    RRND blocks plus a 3x3 convolution form the trunk, the trunk output is
    added back to the downsampled features (global residual), two
    nearest-neighbour upsampling stages scale back up by 4x, and two final
    convolutions followed by a sigmoid produce the output.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the produced image.
        nf: Number of feature channels used throughout the network.
        gc: Growth channels passed to every RRND block.
    """

    def __init__(self, in_channels=3, out_channels=3, nf=64, gc=32):
        super().__init__()

        # Encoder: each 4x4 / stride-2 / padding-1 convolution halves the
        # spatial size, e.g. (3, 256, 256) -> (nf, 128, 128) -> (nf, 64, 64).
        self.conv1 = nn.Conv2d(in_channels, nf, 4, 2, 1)
        self.conv2 = nn.Conv2d(nf, nf, 4, 2, 1)
        self.lrelu = nn.LeakyReLU(0.2, True)

        # Trunk: residual-in-residual dense blocks, shape-preserving.
        self.rrdb1 = RRND(nf, gc)
        self.rrdb2 = RRND(nf, gc)
        self.rrdb3 = RRND(nf, gc)

        self.conv_trunk = nn.Conv2d(nf, nf, 3, 1, 1)

        # Decoder: two (2x nearest upsample -> conv -> LeakyReLU) stages
        # undo the 4x downsampling of the encoder.
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(nf, nf, 3, 1, 1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(nf, nf, 3, 1, 1),
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.de_conv1 = nn.Conv2d(nf, nf, 3, 1, 1)
        self.de_conv2 = nn.Conv2d(nf, out_channels, 3, 1, 1)

    def forward(self, x):
        """Run the network.

        Args:
            x: Tensor with ``in_channels`` channels on dimension 1.

        Returns:
            Tensor with ``out_channels`` channels and sigmoid-activated
            values. For even spatial sizes divisible by 4 the output has
            the same height and width as ``x``.
        """
        out = self.lrelu(self.conv1(x))
        out = self.lrelu(self.conv2(out))

        trunk = self.rrdb1(out)
        trunk = self.rrdb2(trunk)
        trunk = self.rrdb3(trunk)
        trunk = self.conv_trunk(trunk)

        # Global residual connection. This must NOT be `out += trunk`:
        # `out` is still needed by autograd (it is the result of the in-place
        # LeakyReLU and the input of the trunk), so modifying it in place
        # makes backward() fail with an "modified by an inplace operation"
        # RuntimeError.
        out = out + trunk

        out = self.up(out)

        out = self.lrelu(self.de_conv1(out))
        # torch.sigmoid computes the same values as F.sigmoid; the functional
        # alias is flagged as deprecated by some PyTorch releases.
        return torch.sigmoid(self.de_conv2(out))
        

In [28]:
# Shape check: push a random (1, 3, 256, 256) batch through a default Net.
tensor = torch.randn(1, 3, 256, 256)
net = Net()

out = net(tensor)
# Last expression of the cell, so the notebook displays the output shape.
out.shape

torch.Size([1, 64, 64, 64])
torch.Size([1, 64, 64, 64])


torch.Size([1, 3, 256, 256])