In [1]:
from model import rdn
import argparse, os
import numpy as np
import matplotlib.pyplot as plt
import torch

In [11]:
# Residual Dense Network for Image Super-Resolution
# https://arxiv.org/abs/1802.08797

from model import common
from argparse import Namespace

import torch
import torch.nn as nn


def make_rdn(inchannel=6, RDNkSize=3, growth = 16, RDNconfig='C'):
    """Build an RDN from plain keyword arguments instead of a parsed CLI namespace.

    Args:
        inchannel: number of input channels; stored as ``n_colors`` and used as the
            input width of ``SFENet1``.
        RDNkSize: kernel size stored as ``RDNkSize`` for the RDN constructor.
        growth: base feature width, stored as ``G0`` (the RDN's channel count
            between blocks and its ``out_dim``).
        RDNconfig: key into the RDN block-configuration table (this file only
            defines ``'C'``).

    Returns:
        A freshly initialised ``RDN``.
    """
    rdn_args = Namespace(
        G0=growth,
        RDNkSize=RDNkSize,
        RDNconfig=RDNconfig,
        n_colors=inchannel,
    )
    return RDN(rdn_args)

class RDB_Conv(nn.Module):
    """One densely-connected layer of a residual dense block.

    Applies ``Conv2d(inChannels -> growRate, kSize) + ReLU`` and concatenates the
    result onto its own input along the channel axis, so the output carries
    ``inChannels + growRate`` channels. Padding is ``(kSize - 1) // 2``, which keeps
    the spatial size unchanged for odd ``kSize``.
    """

    def __init__(self, inChannels, growRate, kSize=3):
        super().__init__()
        same_pad = (kSize - 1) // 2
        self.conv = nn.Sequential(
            nn.Conv2d(inChannels, growRate, kSize, padding=same_pad, stride=1),
            nn.ReLU(),
        )

    def forward(self, x):
        # Dense connectivity: keep every earlier feature map and append the new ones.
        new_features = self.conv(x)
        return torch.cat([x, new_features], dim=1)

class RDB(nn.Module):
    """Residual dense block: C densely-connected conv layers, 1x1 local fusion, skip add.

    Args:
        growRate0: channels entering and leaving the block (G0).
        growRate: channels appended by each inner layer (G).
        nConvLayers: number of densely-connected inner layers (C).
        kSize: kernel size of every inner layer; forwarded to each ``RDB_Conv``.
    """

    def __init__(self, growRate0, growRate, nConvLayers, kSize=3):
        super(RDB, self).__init__()
        G0 = growRate0
        G = growRate
        C = nConvLayers

        # Layer c sees the block input plus the c*G channels appended by the layers
        # before it. kSize must be passed through, otherwise every layer falls back
        # to RDB_Conv's own default kernel size.
        self.convs = nn.Sequential(*[
            RDB_Conv(G0 + c * G, G, kSize) for c in range(C)
        ])

        # Local Feature Fusion: a 1x1 conv squeezing G0 + C*G channels back to G0 so
        # the residual add in forward() is shape-compatible with the block input.
        self.LFF = nn.Conv2d(G0 + C * G, G0, 1, padding=0, stride=1)

    def forward(self, x):
        # Local residual learning: fused dense features plus the block input.
        return self.LFF(self.convs(x)) + x

class RDN(nn.Module):
    """Residual Dense Network feature extractor (no upsampling tail).

    Shallow features are extracted with two ``Conv3d`` layers, giving
    ``(B, G0, D, H, W)`` volumes. Those volumes are folded to ``(B, G0, D*H, W)``
    images so the ``Conv2d``-based residual dense blocks and global feature fusion
    can process them.

    Args:
        args: namespace providing ``G0`` (base channel count), ``RDNkSize``
            (kernel size), ``RDNconfig`` (key into the block table below) and
            ``n_colors`` (input channels).

    Attributes:
        out_dim: number of output channels (``G0``).
    """

    def __init__(self, args):
        super(RDN, self).__init__()
        G0 = args.G0
        kSize = args.RDNkSize

        # number of RDB blocks, conv layers, out channels
        self.D, C, G = {
            'C': (3, 5, 32),
        }[args.RDNconfig]

        # Shallow feature extraction net (volumetric)
        self.SFENet1 = nn.Conv3d(args.n_colors, G0, kSize, padding=(kSize-1)//2, stride=1)
        self.SFENet2 = nn.Conv3d(G0, G0, kSize, padding=(kSize-1)//2, stride=1)

        # Residual dense blocks and dense feature fusion
        self.RDBs = nn.ModuleList([
            RDB(growRate0=G0, growRate=G, nConvLayers=C, kSize=kSize)
            for _ in range(self.D)
        ])

        # Global Feature Fusion: 1x1 squeeze of the D concatenated RDB outputs, then a
        # kSize x kSize refinement conv.
        self.GFF = nn.Sequential(
            nn.Conv2d(self.D * G0, G0, 1, padding=0, stride=1),
            nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1),
        )

        self.out_dim = G0

    @staticmethod
    def _fold_depth(t):
        """Fold a (B, C, D, H, W) volume into a (B, C, D*H, W) image (slices stacked along H)."""
        if t.dim() != 5:
            raise ValueError(
                f"RDN expects 5-D (B, C, D, H, W) features, got shape {tuple(t.shape)}"
            )
        return t.flatten(2, 3)

    def forward(self, x):
        """Extract features from a batched 5-D input.

        Args:
            x: input volume; must be 5-D (batched) for ``Conv3d`` to produce the
                5-D features that ``_fold_depth`` expects.

        Returns:
            Tensor of shape ``(B, out_dim, D*H, W)``.
        """
        f__1 = self.SFENet1(x)
        x = self.SFENet2(f__1)

        # The RDBs and GFF are Conv2d stacks and reject 5-D input, so fold depth
        # into height before using them.
        # NOTE(review): the fold layout is an assumption. The notebook experiments
        # reshaped 16x16x16 volumes to 64x64 instead; confirm the intended layout.
        f__1 = self._fold_depth(f__1)
        x = self._fold_depth(x)

        RDBs_out = []
        for rdb in self.RDBs:
            x = rdb(x)
            RDBs_out.append(x)

        # Global residual: out-of-place add rather than mutating the GFF output.
        return self.GFF(torch.cat(RDBs_out, 1)) + f__1

In [12]:
# Build the encoder, spelling out make_rdn's default arguments for readability.
enc = make_rdn(inchannel=6, RDNkSize=3, growth=16, RDNconfig='C')

In [13]:
# Inspect the encoder's module tree: Conv3d shallow-feature layers followed by Conv2d RDBs.
enc

RDN(
  (SFENet1): Conv3d(6, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (SFENet2): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (RDBs): ModuleList(
    (0-2): 3 x RDB(
      (convs): Sequential(
        (0): RDB_Conv(
          (conv): Sequential(
            (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): ReLU()
          )
        )
        (1): RDB_Conv(
          (conv): Sequential(
            (0): Conv2d(48, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): ReLU()
          )
        )
        (2): RDB_Conv(
          (conv): Sequential(
            (0): Conv2d(80, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): ReLU()
          )
        )
        (3): RDB_Conv(
          (conv): Sequential(
            (0): Conv2d(112, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): ReLU()
          )
        )
        (4): RDB_Conv(

In [55]:
# Smoke-test the encoder's first (Conv3d) stage on a random batch.
# Seeded so that Restart & Run All reproduces the same tensors.
torch.manual_seed(0)
temp = torch.randn(8, 6, 16, 16, 16)  # 6 = make_rdn's default inchannel
# Fold each 16x16x16 feature volume (4096 voxels) into a 64x64 plane for the Conv2d RDBs.
# Batch and channel sizes come from the input and the encoder instead of literals.
f_1 = enc.SFENet1(temp).reshape(temp.shape[0], enc.out_dim, 64, 64)

In [41]:
# Run both shallow-feature Conv3d layers on the random batch, one step at a time.
enc_out = enc.SFENet1(temp)
enc_out = enc.SFENet2(enc_out)

In [42]:
# The Conv3d stack keeps the 5-D layout: (batch, channels, depth, height, width).
enc_out.shape

torch.Size([8, 16, 16, 16, 16])

In [43]:
# Fold each 16x16x16 feature volume into a 64x64 plane so the Conv2d-based RDBs accept it.
# Batch and channel sizes are read from the tensor instead of being hard-coded.
enc_out_reshape = enc_out.reshape(enc_out.shape[0], enc_out.shape[1], 64, 64)

In [44]:
# After the fold the features are 4-D, as the Conv2d layers in the RDBs require.
enc_out_reshape.shape

torch.Size([8, 16, 64, 64])

In [45]:
# Rebuild RDN configuration 'C' by hand: D RDB blocks, C conv layers each, growth rate G.
D, C, G = {
    'C': (3, 5, 32),
}['C']
G0 = 16     # base channel count; matches make_rdn's default growth (enc.out_dim)
kSize = 3   # kernel size of the inner RDB convs and the GFF refinement conv

# Build exactly D blocks, since the decoding loop below only ever runs the first D.
RDBs = nn.ModuleList([
    RDB(growRate0=G0, growRate=G, nConvLayers=C, kSize=kSize)
    for _ in range(D)
])

In [46]:
# Alias (not a copy) of the folded encoder features; the RDB loop below rebinds this name.
dec_temp = enc_out_reshape

In [47]:
# Push the folded encoder features through the first D RDBs, keeping every block's output
# for global feature fusion. Restarting from enc_out_reshape makes the cell safe to re-run
# (otherwise a second run would feed the previous output back through the blocks).
dec_temp = enc_out_reshape
RDBs_out = []
for rdb in RDBs[:D]:
    dec_temp = rdb(dec_temp)
    RDBs_out.append(dec_temp)

In [48]:
# One stored output per RDB that was run (D of them).
len(RDBs_out)

3

In [49]:
# One line per stored RDB output. Each RDB adds its fused features back onto its input,
# so every output should have the same shape as the block input.
for rdb_out in RDBs_out:
    print(rdb_out.shape)

torch.Size([8, 16, 64, 64])
torch.Size([8, 16, 64, 64])
torch.Size([8, 16, 64, 64])


In [34]:
# Concatenating along channels gives D * G0 channels, the input width GFF's 1x1 conv expects.
torch.cat(RDBs_out,1).shape

torch.Size([8, 48, 64, 64])

In [50]:
# Global Feature Fusion, mirroring RDN.GFF: a 1x1 conv squeezes the D concatenated RDB
# outputs back to G0 channels, then a kSize x kSize conv refines them.
GFF = nn.Sequential(
    nn.Conv2d(D * G0, G0, kernel_size=1, padding=0, stride=1),
    nn.Conv2d(G0, G0, kernel_size=kSize, padding=(kSize - 1) // 2, stride=1),
)

In [56]:
# Fuse the RDB outputs and add the shallow-feature skip connection, as RDN.forward does.
x = GFF(torch.cat(RDBs_out, dim=1)) + f_1

In [58]:
# The fused features keep the same shape as each RDB output.
x.shape

torch.Size([8, 16, 64, 64])