In [1]:
import torch
from torch import nn

In [3]:
import os, sys
project_dir = os.path.join(os.getcwd(),'../..')
if project_dir not in sys.path:
    sys.path.append(project_dir)

hyspeclab_dir = os.path.join(project_dir, 'HySpecLab')
if hyspeclab_dir not in sys.path:
    sys.path.append(hyspeclab_dir)

ipdl_dir = os.path.join(project_dir, 'modules/IPDL')
if ipdl_dir not in sys.path:
    sys.path.append(ipdl_dir)

ae_dir = os.path.join(project_dir, 'modules/AutoEncoder')
if ae_dir not in sys.path:
    sys.path.append(ae_dir)

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [173]:
from torch import Tensor

class EncoderPrior(nn.Module):
    def __init__(self, in_channels, out_channels) -> None:
        super(EncoderPrior, self).__init__()

        self.encode = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x: Tensor):
        return self.encode(x)

class Decoder(nn.Module):
    def __init__(self, in_channels, out_channels) -> None:
        super(Decoder, self).__init__()

        self.decode = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x: Tensor):
        return self.decode(x)

# class UnDIP(nn.module):
#     def __init__(self, in_channels : list, out_channels : list, skip_channels : list):
#         super(UnDIP, self).__init__()
#         if not(isinstance(in_channels, list)) or not(isinstance(out_channels, list)) or not(isinstance(skip_channels, list)):
#             raise ValueError('Parameters must be list')

#         if len(in_channels) != len(out_channels) != len(skip_channels):
#             raise ValueError('The parameters must contain the samme number of elements')
        
#         for 

#         self.encoder = Prior(in_channels, )

#         pass
    
#     def forward(self):
# #         pass

In [177]:
in_channels = [3, 25, 32, 64]
out_channels = [3, 25, 32, 64]
skip_channels = [2, 4, 6]

In [178]:
from collections import deque
from itertools import islice

def sliding_window_iter(iterable, size):
    '''
        Iterate through iterable using a sliding window of several elements.
        Important: It is a generator!.
        
        Creates an iterable where each element is a tuple of `size`
        consecutive elements from `iterable`, advancing by 1 element each
        time. For example:
        >>> list(sliding_window_iter([1, 2, 3, 4], 2))
        [(1, 2), (2, 3), (3, 4)]
        
        source: https://codereview.stackexchange.com/questions/239352/sliding-window-iteration-in-python
    '''
    iterable = iter(iterable)
    window = deque(islice(iterable, size), maxlen=size)
    for item in iterable:
        yield tuple(window)
        window.append(item)
    if window:  
        # needed because if iterable was already empty before the `for`,
        # then the window would be yielded twice.
        yield tuple(window)

In [180]:
encode_modules = []
for idx, (in_channel, out_channel) in enumerate(sliding_window_iter(in_channels, 2)):
    encode_modules.append(EncoderPrior(in_channel, out_channel))

decode_modules = []
for idx, (in_channel, out_channel) in enumerate(sliding_window_iter(out_channels[::-1], 2)):
    decode_modules.append(Decoder(in_channel + skip_channels[::-1][idx], out_channel))


In [181]:
encoder = nn.Sequential(*encode_modules)
decoder = nn.Sequential(*decode_modules)

In [182]:
class Level(nn.Module):
    def __init__(self, in_channels, out_channels, skip_channels = 0, deeper = None):
        super(Level, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels if deeper else in_channels

        if skip_channels:
            self.skip_ = nn.Sequential(
                nn.Conv2d(in_channels, skip_channels, kernel_size=1, stride=1),
                nn.Sigmoid()
            )

        if not deeper:
            self.f = EncoderPrior(in_channels, out_channels)
            self.decoder = Decoder(out_channels + skip_channels, in_channels)
        else:
            if not isinstance(deeper, Level):
                raise ValueError('Meh!')
            
            self.f = nn.Sequential(
                EncoderPrior(in_channels, deeper.in_channels),
                deeper
            )

            self.decoder = Decoder(deeper.out_channels + skip_channels, out_channels)

    def forward(self, x: Tensor):
        x = torch.cat([self.skip_(x), self.decoder.decode[0](self.f(x))], dim=1) if hasattr(self, 'skip_') else self.f(x)
        print(x.shape)
        return self.decoder.decode[1:](x) if hasattr(self, 'skip_') else self.decoder(x)


In [183]:
test = Level(32, 64, skip_channels=5)
test2 = Level(26, 12, deeper=test)
test3 = Level(3, 6, deeper=test2, skip_channels=12)
# test2 = Level(test, decoder[1])
# test3 = Level(test2, decoder[2])

In [184]:
x = torch.rand((1,3,128,128))
test3(x).shape

torch.Size([1, 69, 32, 32])
torch.Size([1, 32, 32, 32])
torch.Size([1, 24, 128, 128])


torch.Size([1, 6, 128, 128])

In [127]:
print(hasattr(test, 'skip_'))

True


In [185]:
test3

Level(
  (skip_): Sequential(
    (0): Conv2d(3, 12, kernel_size=(1, 1), stride=(1, 1))
    (1): Sigmoid()
  )
  (f): Sequential(
    (0): EncoderPrior(
      (encode): Sequential(
        (0): Conv2d(3, 26, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): Sigmoid()
        (2): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        (3): Conv2d(26, 26, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): Sigmoid()
      )
    )
    (1): Level(
      (f): Sequential(
        (0): EncoderPrior(
          (encode): Sequential(
            (0): Conv2d(26, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): Sigmoid()
            (2): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
            (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (4): Sigmoid()
          )
        )
        (1): Level(
          (skip_): Sequential(
            (0): Conv2d(32, 5, 

In [79]:
test2.f

Level(
  (f): EncoderPrior(
    (skip_): Sequential(
      (0): Conv2d(25, 4, kernel_size=(1, 1), stride=(1, 1))
      (1): Sigmoid()
    )
    (encode): Sequential(
      (0): Conv2d(25, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): Sigmoid()
      (2): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): Sigmoid()
    )
  )
  (decoder): Decoder(
    (decode): Sequential(
      (0): Upsample(scale_factor=2.0, mode=nearest)
      (1): Conv2d(36, 25, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (2): Sigmoid()
      (3): Conv2d(25, 25, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): Sigmoid()
    )
  )
)