In [6]:
import torch

In [7]:
tmp = torch.rand((1,7,32,32,4))

In [8]:

from model.rdn import make_rdn

In [9]:
encoder = make_rdn(in_chans=7,growth = 32)

In [10]:
encoder

RDN(
  (SFENet1): Conv3d(7, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (SFENet2): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (RDBs): ModuleList(
    (0-4): 5 x RDB(
      (convs): Sequential(
        (0): RDB_Conv(
          (conv): Sequential(
            (0): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
            (1): ReLU()
          )
        )
        (1): RDB_Conv(
          (conv): Sequential(
            (0): Conv3d(96, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
            (1): ReLU()
          )
        )
        (2): RDB_Conv(
          (conv): Sequential(
            (0): Conv3d(160, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
            (1): ReLU()
          )
        )
        (3): RDB_Conv(
          (conv): Sequential(
            (0): Conv3d(224, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
            (1): ReLU()
       

In [11]:
tmp2 = encoder(tmp)

In [12]:
tmp2.shape

torch.Size([1, 32, 32, 32, 4])

In [31]:
decoder = ImplicitDecoder(in_channels=32)

In [33]:
decoder(tmp2,(45,45,6)).shape

torch.Size([1, 4, 45, 45, 6]) torch.Size([1, 864, 45, 45, 6])


torch.Size([1, 5, 45, 45, 6])

In [30]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from model import common
from argparse import Namespace
import random 
import math
from model.rdn import make_rdn
from model.resblock import ResBlock_3d
import unfoldNd

class SineAct(nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, x):
        return torch.sin(x)

class ImplicitDecoder(nn.Module):
    def __init__(self, in_channels=64, hidden_dims=[64, 64, 64, 64, 64],out_chans= 5):
        super().__init__()

        last_dim_K = in_channels * 27
        
        last_dim_Q = 4

        self.K = nn.ModuleList()
        self.Q = nn.ModuleList()
        
        for hidden_dim in hidden_dims:
            self.K.append(nn.Sequential(nn.Conv3d(last_dim_K, hidden_dim, 1),
                                        nn.ReLU(),
                                        ResBlock_3d(channels = hidden_dim, nConvLayers = 4)
                                        ))    
            self.Q.append(nn.Sequential(nn.Conv3d(last_dim_Q, hidden_dim, 1),
                                        SineAct()))
            last_dim_K = hidden_dim
            last_dim_Q = hidden_dim
            
        self.last_layer = nn.Conv3d(hidden_dims[-1], out_chans, 1)
        
        self.in_branch = nn.Sequential(nn.Conv3d(in_channels * 27, hidden_dims[-2], 1),
                            nn.ReLU(),
                            nn.Conv3d(hidden_dims[-2],hidden_dims[-1], 1),
                            nn.ReLU(),
                            nn.Conv3d(hidden_dims[-1],out_chans, 1),
                            nn.ReLU())
        
    def _make_pos_encoding(self, x, size): 
        B, C, H, W, D = x.shape
        H_up, W_up, D_up = size
       
        h_idx = -1 + 1/H + 2/H * torch.arange(H, device=x.device).float()
        w_idx = -1 + 1/W + 2/W * torch.arange(W, device=x.device).float()
        d_idx = -1 + 1/D + 2/D * torch.arange(D, device=x.device).float()
        in_grid = torch.stack(torch.meshgrid(h_idx, w_idx,d_idx), dim=0)

        h_idx_up = -1 + 1/H_up + 2/H_up * torch.arange(H_up, device=x.device).float()
        w_idx_up = -1 + 1/W_up + 2/W_up * torch.arange(W_up, device=x.device).float()
        d_idx_up = -1 + 1/D_up + 2/D_up * torch.arange(D_up, device=x.device).float()
        up_grid = torch.stack(torch.meshgrid(h_idx_up, w_idx_up,d_idx_up), dim=0)
        
        rel_grid = (up_grid - F.interpolate(in_grid.unsqueeze(0), size=(H_up, W_up,D_up), mode='nearest-exact'))
        rel_grid[:,0,:,:] *= H
        rel_grid[:,1,:,:] *= W
        rel_grid[:,2,:,:] *= D

        return rel_grid.contiguous().detach()

    def step(self, x, syn_inp):
        
        q = syn_inp
        
        k = x
        
        for i in range(len(self.K)):
            
            k = self.K[i](k)
            q = k*self.Q[i](q)
            
        q = self.last_layer(q)
        
        return q + self.in_branch(x)


    def forward(self, x, size):
        B, C, H_in, W_in,D_in = x.shape
        
        rel_coord = (self._make_pos_encoding(x, size).expand(B, -1, *size))
        
        ratio = (x.new_tensor([math.sqrt((H_in*W_in*D_in)/(size[0]*size[1]*size[2]))]).view(1, -1, 1, 1).expand(B, -1, *size))
        
        syn_inp = torch.cat([rel_coord, ratio], dim=1)
        
        x = F.interpolate(unfoldNd.unfoldNd(x, 3, padding=1).view(B, C*27, H_in, W_in,D_in), size=ratio.shape[-3:],mode = "trilinear")
        
        print(syn_inp.shape,x.shape)
        pred = self.step(x, syn_inp)
        return pred