In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from model import common
from argparse import Namespace
import random 
import math
from model.rdn import make_rdn
from model.resblock import ResBlock
import unfoldNd
from model.models import CSEUnetModel
from arb_decoder_2d import ImplicitDecoder

class DMRI_RCAN(nn.Module):
    def __init__(self,in_chans = 7,int_chans = 32):
        super().__init__()
        self.encoder = CSEUnetModel(in_chans=in_chans,out_chans=64,chans=int_chans,num_pool_layers=3,drop_prob = 0,attention_type='cSE',reduction=16)
        self.decoder = ImplicitDecoder(in_channels= 64) 
    
    def set_scale(self, scale):
        self.scale = scale

    def forward(self, inp):
        
        B,C,H,W,D = inp.shape
        # print(self.scale)
        H_hr = round(H*self.scale[0])
        W_hr = round(W*self.scale[1])
        
        size = [H_hr, W_hr]
        
        feat = self.encoder(inp)
        # print(feat.shape)
        # latent = self.latent_layer(feat)
        
        pred = self.decoder(feat,size)
        
        return pred



ModuleNotFoundError: No module named 'arb_decoder_2d'

In [7]:
from option import args

In [9]:

from model.models import CSEUnetModel

In [11]:
encoder = CSEUnetModel(in_chans=7,out_chans=64,chans=32,num_pool_layers=3,drop_prob = 0,attention_type='cSE',reduction=16)
        

In [12]:
encoder

CSEUnetModel(
  (down_sample_layers): ModuleList(
    (0): ConvBlock(in_chans=7, out_chans=32, drop_prob=0)
    (1): ConvBlock(in_chans=32, out_chans=64, drop_prob=0)
    (2): ConvBlock(in_chans=64, out_chans=128, drop_prob=0)
  )
  (conv): ConvBlock(in_chans=128, out_chans=128, drop_prob=0)
  (up_sample_layers): ModuleList(
    (0): ConvBlock(in_chans=256, out_chans=64, drop_prob=0)
    (1): ConvBlock(in_chans=128, out_chans=32, drop_prob=0)
    (2): ConvBlock(in_chans=64, out_chans=32, drop_prob=0)
  )
  (conv2): Sequential(
    (0): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))
    (1): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1))
    (2): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
  )
)

In [13]:
buf = torch.rand((1,7,32,32))

In [22]:
tmp = encoder(buf)

In [23]:
tmp.shape

torch.Size([1, 64, 32, 32])

In [50]:
decoder = ImplicitDecoder()

In [51]:
size = (45,45)

In [52]:
decoder

ImplicitDecoder(
  (K): ModuleList(
    (0): Sequential(
      (0): Conv2d(576, 64, kernel_size=(1, 1), stride=(1, 1))
      (1): ReLU()
      (2): ResBlock(
        (convs): Sequential(
          (0): Res_Conv(
            (conv): Sequential(
              (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): ReLU()
            )
          )
          (1): Res_Conv(
            (conv): Sequential(
              (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): ReLU()
            )
          )
          (2): Res_Conv(
            (conv): Sequential(
              (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): ReLU()
            )
          )
          (3): Res_Conv(
            (conv): Sequential(
              (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): ReLU()
            )
          )
        )
        (CBAM): CBAM(
     

In [53]:
decoder(tmp,size).shape

torch.Size([1, 3, 45, 45]) torch.Size([1, 576, 45, 45])


torch.Size([1, 5, 45, 45])

In [49]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from model import common
from argparse import Namespace
import random 
import math
from model.rdn import make_rdn
from model.resblock import ResBlock

class SineAct(nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, x):
        return torch.sin(x)

def patch_norm_2d(x, kernel_size=3):
    mean = F.avg_pool2d(x, kernel_size=kernel_size, padding=kernel_size//2)
    mean_sq = F.avg_pool2d(x**2, kernel_size=kernel_size, padding=kernel_size//2)
    var = mean_sq - mean**2
    return (x-mean)/(var + 1e-6)

class ImplicitDecoder(nn.Module):
    def __init__(self, in_channels=64, hidden_dims=[64, 64, 64, 64, 64],out_chans= 5):
        super().__init__()

        last_dim_K = in_channels * 9
        
        last_dim_Q = 3

        self.K = nn.ModuleList()
        self.Q = nn.ModuleList()
        
        for hidden_dim in hidden_dims:
            self.K.append(nn.Sequential(nn.Conv2d(last_dim_K, hidden_dim, 1),
                                        nn.ReLU(),
                                        ResBlock(channels = hidden_dim, nConvLayers = 4)
                                        ))    
            self.Q.append(nn.Sequential(nn.Conv2d(last_dim_Q, hidden_dim, 1),
                                        SineAct()))
            last_dim_K = hidden_dim
            last_dim_Q = hidden_dim
            
        self.last_layer = nn.Conv2d(hidden_dims[-1], out_chans, 1)
        
        self.in_branch = nn.Sequential(nn.Conv2d(in_channels * 9, hidden_dims[-2], 1),
                            nn.ReLU(),
                            nn.Conv2d(hidden_dims[-2],hidden_dims[-1], 1),
                            nn.ReLU(),
                            nn.Conv2d(hidden_dims[-1],out_chans, 1),
                            nn.ReLU())
        
    def _make_pos_encoding(self, x, size): 
        B, C, H, W = x.shape
        H_up, W_up = size
       
        h_idx = -1 + 1/H + 2/H * torch.arange(H, device=x.device).float()
        w_idx = -1 + 1/W + 2/W * torch.arange(W, device=x.device).float()
        in_grid = torch.stack(torch.meshgrid(h_idx, w_idx), dim=0)

        h_idx_up = -1 + 1/H_up + 2/H_up * torch.arange(H_up, device=x.device).float()
        w_idx_up = -1 + 1/W_up + 2/W_up * torch.arange(W_up, device=x.device).float()
        up_grid = torch.stack(torch.meshgrid(h_idx_up, w_idx_up), dim=0)
        
        rel_grid = (up_grid - F.interpolate(in_grid.unsqueeze(0), size=(H_up, W_up), mode='nearest-exact'))
        rel_grid[:,0,:,:] *= H
        rel_grid[:,1,:,:] *= W

        return rel_grid.contiguous().detach()

    def step(self, x, syn_inp):
        
        q = syn_inp
        
        k = x
        
        for i in range(len(self.K)):
            
            k = self.K[i](k)
            q = k*self.Q[i](q)
            
        q = self.last_layer(q)
        
        return q + self.in_branch(x)


    def forward(self, x, size):
        B, C, H_in, W_in = x.shape
        
        rel_coord = (self._make_pos_encoding(x, size).expand(B, -1, *size))
        
        ratio = (x.new_tensor([math.sqrt((H_in*W_in)/(size[0]*size[1]))]).view(1, -1, 1, 1).expand(B, -1, *size))
        
        syn_inp = torch.cat([rel_coord, ratio], dim=1)
        
        x = F.interpolate(F.unfold(x, 3, padding=1).view(B, C*9, H_in, W_in), size=syn_inp.shape[-2:], mode='bilinear')
        
        print(syn_inp.shape,x.shape)
        pred = self.step(x, syn_inp)
        return pred