In [1]:
import torch
from torch import nn

In [2]:
import os, sys
project_dir = os.path.join(os.getcwd(),'../..')
if project_dir not in sys.path:
    sys.path.append(project_dir)

hyspeclab_dir = os.path.join(project_dir, 'HySpecLab')
if hyspeclab_dir not in sys.path:
    sys.path.append(hyspeclab_dir)

ipdl_dir = os.path.join(project_dir, 'modules/IPDL')
if ipdl_dir not in sys.path:
    sys.path.append(ipdl_dir)

ae_dir = os.path.join(project_dir, 'modules/AutoEncoder')
if ae_dir not in sys.path:
    sys.path.append(ae_dir)

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
from torch import Tensor
from collections import OrderedDict

class Encoder(nn.Module):
    def __init__(self, in_channels, out_channels) -> None:
        super(Encoder, self).__init__()

        self.encode = nn.Sequential(OrderedDict([
                ('conv_0', nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)), 
                ('act_0', nn.Sigmoid()), 
                ('pooling', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), 
                ('conv_1', nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)), 
                ('act_1', nn.Sigmoid())
            ])
        )

    def forward(self, x: Tensor):
        return self.encode(x)

class Decoder(nn.Module):
    def __init__(self, in_channels, out_channels) -> None:
        super(Decoder, self).__init__()

        self.decode = nn.Sequential(OrderedDict([
                ('upsampling', nn.Upsample(scale_factor=2, mode='nearest')),
                ('conv_0', nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)),
                ('act_0', nn.Sigmoid()),
                ('conv_1', nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)),
                ('act_1', nn.Sigmoid())
            ])
        )

    def forward(self, x: Tensor):
        return self.decode(x)

In [13]:
from collections import OrderedDict

class Level(nn.Module):
    def __init__(self, in_channels, out_channels, skip_channels = 0, deeper = None):
        super(Level, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels if deeper else in_channels

        if skip_channels:
            self.skip_ = nn.Sequential(OrderedDict([
                    ('conv', nn.Conv2d(in_channels, skip_channels, kernel_size=1, stride=1)),
                    ('activation', nn.Sigmoid())
                ])
            )

        if not deeper:
            self.f = Encoder(in_channels, out_channels)
            self.decoder = Decoder(out_channels + skip_channels, in_channels)
        else:
            if not isinstance(deeper, Level):
                raise ValueError('Meh!')
            
            self.f = nn.Sequential(
                Encoder(in_channels, deeper.in_channels),
                deeper
            )

            print(out_channels)
            self.decoder = Decoder(deeper.out_channels + skip_channels, out_channels)

    def forward(self, x: Tensor):
        x = torch.cat([self.skip_(x), self.decoder.decode[0](self.f(x))], dim=1) if hasattr(self, 'skip_') else self.f(x)
        return self.decoder.decode[1:](x) if hasattr(self, 'skip_') else self.decoder(x)


In [14]:
class UnDIP(nn.Module):
    ''' 
        HyperSpectral Unmixing using Deep Image Prior (UnDIP)

        Parameters
        ----------
            in_channels
            out_channels
            skip_channels

        Reference
        ---------
            [1] UnDIP: Hyperspectral Unmixing Using Deep Image Prior (10.1109/TGRS.2021.3067802)
    '''
    def __init__(self, in_channels, out_channels, skip_channels, n_endmembers=4) -> None:
        ''' 
        
        '''
        super(UnDIP, self).__init__()
        if not(isinstance(in_channels, list)) or not(isinstance(out_channels, list)) or not(isinstance(skip_channels, list)):
            raise ValueError('Parameters must be list')

        if len(in_channels) != len(out_channels) != len(skip_channels):
            raise ValueError('The parameters must contain the samme number of elements')
        
        out_channels_inv = out_channels[::-1]
        skip_channels_inv = skip_channels[::-1]
        for idx, in_channel in enumerate(in_channels[::-1]):
            self.prior = Level(in_channel, out_channels_inv[idx], skip_channels=skip_channels_inv[idx], deeper=(self.prior if hasattr(self, 'prior') else None) )

        self.unmix = nn.Sequential(OrderedDict([
            ('conv_0', nn.Conv2d(out_channels[0], out_channels[0], kernel_size=3, stride=1, padding=1)),
            ('act_0', nn.LeakyReLU(negative_slope=.1)),
            ('bn_0', nn.BatchNorm2d(out_channels[0])),
            ('ee_conv', nn.Conv2d(out_channels[0], n_endmembers, kernel_size=3, stride=1, padding=1)),
            ('ee_act', nn.Softmax(dim=1)),
        ]))

    def forward(self, x : Tensor) -> Tensor:
        x = self.prior(x)
        return self.unmix(x)



In [15]:
n_bands = 116
n_endmembers = 6

in_channels = [n_bands, 12]
out_channels = [10, 4]
skip_channels = [0, 0]

model = UnDIP(in_channels, out_channels, skip_channels, n_endmembers)

10


In [21]:
test = torch.rand((6, n_bands, 16, 16))
result = model(test)

In [22]:
endmembers = torch.rand(1, n_endmembers, n_bands)
endmembers.shape

torch.Size([1, 6, 116])

In [23]:
torch.matmul(torch.transpose(endmembers, 2, 1), result.flatten(start_dim=2)).shape
# result * endmembers

torch.Size([6, 116, 256])

In [None]:
result[0].flatten(start_dim=1).shape

In [None]:
torch.transpose(endmembers, 2, 1) == endmembers

In [None]:
torch.transpose(endmembers, 2, 1).shape

In [None]:
test = result[0]
test[0, 10, 10] + test[1, 10, 10] + test[2, 10, 10] + test[3, 10, 10] + test[4, 10, 10] + test[5, 10, 10]

In [None]:
result.sum(dim=1)

In [None]:
m = nn.Softmax(dim=1)
input = torch.randn(1, 3, 4, 4)
output = m(input)

In [None]:
input.shape

In [None]:
output.sum(dim=1)

In [None]:
x = torch.rand((1,3,128,128))
test3(x).shape

In [None]:
print(hasattr(test, 'skip_'))

In [None]:
test3

In [None]:
test2.f