In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as functional
import torchfields
from mmcv.ops import DeformConv2dPack
from tifffile import imread as tiff_read
from tifffile import imwrite as tiff_write
from matplotlib import pyplot as plt
import torchvision.transforms
from motion_generate import MotionGenerator
import numpy

#TODO: Residual connection may required
class FeatureExtraction(nn.Module): # return feature_lv1,feature_lv2,feature_lv3
    def __init__(self,origin_channel,internal_channel=16,lv=3) -> None:
        super().__init__()
        self.relu = nn.ReLU()
        self.level = lv
        self.feature_layer = nn.ModuleList()
        self.feature = []
        self.max_pool = nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
        self.origin_channel = origin_channel
        for i in range(self.level):
            input_channel = origin_channel if i==0 else internal_channel
            feature_block = nn.Sequential(
                                nn.Conv2d(input_channel,internal_channel,3,padding=1),
                                self.relu,
                                nn.Conv2d(internal_channel,internal_channel,3,padding=1),
                                self.relu
                            )
            self.feature_layer.append(feature_block)
    
    def forward(self,origin_image):
        for i in range(self.level):
            input = origin_image if i==0 else self.feature[i-1]
            self.feature.append(self.feature_layer[i](input))
        return self.feature
            
class Data(object):
        def __init__(self,image = torch.zeros(1,1,1,1)) -> None:
            self.image = image
            self.feature = [] 

class Algin(nn.Module):
    def __init__(self,original_channel,internal_channel = 16,groups=8,lv=3) -> None:
        super().__init__()
        self.internal_channel = internal_channel
        self.groups=groups
        self.relu = nn.ReLU()
        self.level=lv
        self.feature_extraction = FeatureExtraction(original_channel,internal_channel) #feature -> [tensor(N,C,H,W),]
        
        self.offset_bottleneck =  nn.ModuleList()
        '''
        offset bottleneck conv layer  
        input: concat(reference image feature ,unreg image feature): tensor(N,2C,H/2**lv,W/2**lv)
        output: internal offset feature :tensor(N,C,H,W)
        '''
        self.offset_generator =  nn.ModuleList()
        '''
        offset generate conv layer  
        input: concat(internal offset feature,upsambled low level offset): tensor(N,2C,H/2**lv,W/2**lv)
                or if LOWEST level:
                        internal offset feature:tensor(N,C,H,W)
        output: offset tensor(N,3*3*2,H,W)
        '''
        self.deformable_conv = nn.ModuleList()
        '''
        input: unreg image feature: tensor(N,C,H/2**lv,W/2**lv),  
               offset: tensor(N,3*3*2,H,W)
        out: internal feature: tensor(N,C,H/2**lv,W/2**lv)
        '''
        self.feature_conv = nn.ModuleList()
        '''
        input : concat(internal feature,upsambled lower level aligned feature): tensor(N,2C,H/2**lv,W/2**lv)
        out : aligned feature: tensor(N,C,H/2**lv,W/2**lv)
        '''
        self.displace_field_predict_offset_bottleneck = nn.Conv2d(2*internal_channel,internal_channel,3,padding=1)
        self.displace_field_predict_offset_generator = nn.Conv2d(2*internal_channel,36,3,padding=1)
        self.displace_field_predict_deform_conv = DeformConv2dPack(internal_channel,2,3,padding=1,deform_groups=2)
        for i in range(self.level):
            if i != self.level-1 :
                self.offset_bottleneck.append(nn.Conv2d(2*internal_channel,internal_channel,3,padding=1))
                self.offset_generator.append(nn.Conv2d(2*internal_channel,18,3,padding=1))
                self.deformable_conv.append(DeformConv2dPack(internal_channel,internal_channel,3,
                                                         padding=1))
                self.feature_conv.append(nn.Conv2d(2*internal_channel,internal_channel,3,padding=1))
    
    def forward(self,ref_image,unreg_image):
        self.ref_data = Data(ref_image)
        self.unreg_data = Data(unreg_image)
        self.aligned_data = Data()
        self.aligned_data.offset = []
        self.ref_data.feature =  self.feature_extraction(self.ref_data.image) 
        self.unreg_data.feature = self.feature_extraction(self.unreg_data.image)
        
        for i in range(self.level):
            self.aligned_data.offset.append(torch.cat((self.ref_data.feature[i],self.unreg_data.feature[i]),dim=1))
            self.aligned_data.feature.append('python list sucks')
        for i in range(self.level,-1,-1):
            self.aligned_data.offset[i]=self.relu(self.offset_bottleneck[i](self.aligned_data.offset[i]))
            if i==self.level-1:
                internal_offset_feature = self.aligned_data.offset[i]
            else:
                upsambled_lower_level_offset = functional.interpolate(self.aligned_data.offset[i+1],scale_factor=2,mode='bilinear') 
                internal_offset_feature = torch.cat((self.aligned_data.offset[i],upsambled_lower_level_offset),dim=1)
            self.aligned_data.offset[i] =self.relu(self.offset_generator[i](internal_offset_feature))
            self.aligned_data.feature[i]= self.relu(self.deformable_conv[i](self.unreg_data.feature[i],self.aligned_data.offset[i]))
            upsambled_lower_level_feature = functional.interpolate(self.aligned_data.feature[i+1],scale_factor=2,mode='bilinear')
            self.aligned_data.feature[i] = torch.cat((self.aligned_data.feature[i],upsambled_lower_level_feature),dim=1)
            self.aligned_data.feature[i] = self.relu(self.feature_conv[i](self.aligned_data.feature[i]))
        displace_field_offset = self.displace_field_predict_offset_bottleneck(
                        torch.concat((self.ref_data.feature[0],self.aligned_data.feature[0]),dim=1))
        displace_field_offset = self.displace_field_predict_offset_generator(
                                torch.cat((displace_field_offset,self.aligned_data.offset[0]),dim=1))
        displace_field = self.displace_field_predict_deform_conv(self.aligned_data.feature[0],offset =displace_field_offset)
        return displace_field
          
class DenoisedDataset(torch.utils.data.Dataset):
    def __init__(self,img_dir='data/',) -> None:
        super().__init__()
        self.img_dir = img_dir

    def __len__():
        return None

    def __getitem__(self,idx):
        imgpath = self.imgdir + f'{idx}.tif'
    

In [8]:

print('model_define as main')
#test FeatureExtraction block
f = FeatureExtraction(1)
t = torch.rand(1,1,64,64)
print(t)
t = f(t)
for fe in t:
    print(fe.shape)
g = Algin(1)

model_define as main
tensor([[[[0.0268, 0.2877, 0.0921,  ..., 0.1785, 0.3015, 0.6776],
          [0.3145, 0.3739, 0.9251,  ..., 0.4944, 0.2459, 0.9978],
          [0.3593, 0.7910, 0.8134,  ..., 0.5160, 0.6391, 0.4028],
          ...,
          [0.4550, 0.3278, 0.4619,  ..., 0.0608, 0.3335, 0.2610],
          [0.7901, 0.5826, 0.8488,  ..., 0.8658, 0.2250, 0.6202],
          [0.8745, 0.9346, 0.2246,  ..., 0.3479, 0.4916, 0.3761]]]])
torch.Size([1, 16, 64, 64])
torch.Size([1, 16, 64, 64])
torch.Size([1, 16, 64, 64])


In [9]:
t2 = torch.rand(1,1,64,64)

In [14]:
for i in range(10-1,-1,-1):
    print(i)

9
8
7
6
5
4
3
2
1
0


In [11]:
u = g(t2,t2)

IndexError: index 3 is out of range