In [285]:
from kunlib import KUNet

In [286]:
import torch.nn as nn
import torch
import numpy as np

In [287]:
class Kernel(nn.Module):
    def __init__(self, input_dim, input_len, 
                 output_dim, output_len, params={}):
        super(Kernel, self).__init__()
        self.input_dim = input_dim
        self.input_len = input_len
        self.output_dim = output_dim
        self.output_len = output_len 
        self.params = params 

        self.is_in_encoder = False #input_len >= output_len
        self.is_in_decoder = False #input_len >= output_len

    def update_params(self, params):
        # Iterate over all fields in the class and update if they exist in params
        self.params = params 
        for key, value in params.items():
            setattr(self, key, value)  # Set the attribute from params
    
class KernelWrapper(nn.Module):
    def __init__(self, kernal, input_dim, input_len, 
                 output_dim=1, output_len=1, 
                 num_hidden_layers=1,  
                 mode="concate", params={},verbose=False):
        super(KernelWrapper, self).__init__()

        # kernal : kernal(input_dim, input_len, output_dim, output_len)
        assert (issubclass(kernal, Kernel) or isinstance(kernal, nn.Module))
        if isinstance(kernal, nn.Module) : 
          print(f"kernel {kernal} heiritated nn.Module may not adapt.")

        self.input_dim, self.input_len, self.output_dim, self.output_len = \
                        input_dim, input_len, output_dim, output_len
        self.verbose = verbose
        self.num_hidden_layers = num_hidden_layers
        self.hidden_size_list = []

        if issubclass(kernal, Kernel):
            print("kernal ", kernal, "is a Kernel")
            self.kernel = kernal(input_dim, input_len, output_dim, output_len, params=params)
        else:
            assert False, f"kernal {kernal} is not recognized"
            
        self._unet_skip_output = None
        self._unet_skip_input = None

        #self.linear_unet_skip_input = nn.Linear(self.input_dim, self.input_dim)

        self.transpose = False
        self.concat = True if mode == "concate" else False
        self.concat_mode = mode

    def f(self, x):
        if self.verbose : 
          print("---KernelWrapper.f(x) Input x.shape: ", x.shape)
        x = self.kernel(x)
        if self.verbose : 
          print("---KernelWrapper.f(x) Output x.shape: ", x.shape)
        return x

    def forward(self, x, train=False):
        if self.verbose : 
          print("---KernelWrapper.forward(x) Input x.shape:", x.shape)
          print("---train:", train)
          if self._unet_skip_input is not None:
              print("---_unet_skip_input.shape", self._unet_skip_input.shape)
          else:
              print("---_unet_skip_input", self._unet_skip_input)

        if self.transpose and self._unet_skip_input is not None:
            if self.verbose : 
                print("self.transpose and self._unet_skip_input")
                print("--x.shape", x.shape)
            if np.prod(x.shape) == np.prod(self._unet_skip_input.shape):
                if self.verbose : 
                    print( self.concat, x.shape)
                if self.concat:
                    x = torch.cat([x, self._unet_skip_input.reshape(x.shape)], dim=-1)
                    #print( "after, ", self.concat, x.shape)
                else:
                    x = x + self._unet_skip_input.reshape(x.shape)
                #print("_unet_skip_input.shape", self._unet_skip_input.shape)
            #x[len(self._unet_skip_input):] = x[len(self._unet_skip_input):] + self._unet_skip_input
        #x = x.transpose(1, 2)   # # (batch, d_model , lag) to (batch, lag, d_model)
        else:
            pass
        if self.verbose : 
            print("reshape - > x.shape", x.shape)
        x = x.reshape(-1, self.input_len, self.input_dim)

        if self.verbose : 
            print("after reshape - > x.shape", x.shape)
        x = self.f(x)

        if self.verbose : 
            print("after x = self.f(x) - > x.shape", x.shape)
        assert x.shape[1] == self.output_len and x.shape[2] == self.output_dim

        if not self.transpose:
          self._unet_skip_output = x
        return x

# Linear MLP Kernel

In [288]:
class Linear(Kernel):
    def __init__(self, input_dim, input_len, 
                 output_dim, output_len, params={}):
        super(Linear, self).__init__(input_dim, input_len, 
                 output_dim, output_len)
        # declear parameters
        self.activation = "tanh"
        self.drop_out_p = 0.05
        self.num_hidden_layers = 0
        self.update_params(params=params)

        # compute input and output size
        self.in_size = input_len*input_dim
        self.out_size = output_len*output_dim

        # prepare layers
        self.layers = []

        # check in encoder or decoder 
        self.is_in_encoder = (self.input_len >= self.output_len)
        self.is_in_decoder = not self.is_in_encoder

        # in encoder
        if self.is_in_encoder:
          gap = int((self.in_size - self.out_size) / (self.num_hidden_layers + 1))
          self.hidden_size_list = [self.in_size - i * gap for i in range(1, self.num_hidden_layers + 1)]
        
        # in decoder
        else:
          gap = int((self.out_size - self.in_size) / (self.num_hidden_layers + 1))
          self.hidden_size_list = [self.in_size + i * gap for i in range(1, self.num_hidden_layers + 1)]
        # add linear layers
        for i in range(self.num_hidden_layers):
            self.layers.append(nn.Linear(self.in_size, self.hidden_size_list[i]))

            if self.activation.lower() == "relu":
                self.layers.append(nn.ReLU())
            elif self.activation.lower() == "tanh": 
                self.layers.append(nn.Tanh())

            self.layers.append(nn.Dropout(self.drop_out_p))
            self.in_size = self.hidden_size_list[i]

        self.layers.append(nn.Linear(self.in_size, self.out_size))

        self.layers = nn.Sequential(* self.layers)

    def forward(self, x):
        x = x.reshape(-1, self.input_len * self.input_dim)
        #print("x.shape,", x.shape)
        x = self.layers(x)
        x = x.reshape(-1, self.output_len, self.output_dim)
        #print("x.shape,", x.shape)
        return x

In [289]:
kw = KernelWrapper(Linear, input_dim=128, input_len=1, 
                 output_dim=128, output_len=4, 
                 num_hidden_layers=1, 
                 mode="concate", params={"num_hidden_layers":2, 
                                         "drop_out_p":0.05, 
                                         "activation":"tanh"}, 
                                verbose=False)
print(kw)
print(kw.kernel.is_in_encoder)

kernal  <class '__main__.Linear'> is a Kernel
KernelWrapper(
  (kernel): Linear(
    (layers): Sequential(
      (0): Linear(in_features=128, out_features=256, bias=True)
      (1): Tanh()
      (2): Dropout(p=0.01, inplace=False)
      (3): Linear(in_features=256, out_features=384, bias=True)
      (4): Tanh()
      (5): Dropout(p=0.01, inplace=False)
      (6): Linear(in_features=384, out_features=512, bias=True)
    )
  )
)
False


In [290]:
x = kw(torch.rand((4,1,128)))
print(x.shape)

torch.Size([4, 4, 128])


# LSTM Kernel

In [291]:
class LSTM(Kernel):
    def __init__(self, input_dim, input_len, 
                 output_dim, output_len, params={}):
        super(LSTM, self).__init__(input_dim, input_len, 
                 output_dim, output_len)
        # declear parameters
        self.drop_out_p = 0.05
        self.num_hidden_layers = 0
        self.update_params(params=params)


        self.lstm_dim = max(input_dim, output_dim)
        self.lstm_len = max(input_len, output_len)

        # compute input and output size
        self.in_size = input_len*input_dim
        self.out_size = output_len*output_dim
        self.lstm_size = self.lstm_dim*self.lstm_len

        # prepare layers
        self.layers = []

        # check in encoder or decoder 
        self.is_in_encoder = (self.input_len >= self.output_len)
        self.is_in_decoder = not self.is_in_encoder

        # Define the LSTM and Linear layers
        self.linear_projection_in = nn.Linear(self.in_size, self.lstm_size)
        self.linear_projection_out = nn.Linear(self.lstm_size, self.out_size)

        self.lstm = nn.LSTM(self.lstm_dim, self.lstm_dim, 
                            self.num_hidden_layers, dropout=self.drop_out_p, 
                            batch_first=True)

    def forward(self, x):
        """
        Forward pass for LSTM. If we are in encoder mode, process the input sequence through LSTM
        and use the last hidden state to compute the final output using the linear layer.
        """
        #print(x.shape)
        x = x.reshape(-1, self.in_size)
        #print(x.shape)
        x = self.linear_projection_in(x)
        #print(x.shape)
        x = x.reshape(-1, self.lstm_len, self.lstm_dim)
        #print(x.shape)

        # Pass through LSTM
        x, (h_n, c_n) = self.lstm(x)  # x, lstm_out contains all hidden states, h_n is the last hidden state

        # Use the last hidden state (h_n) for linear transformation
        #print(x.shape, h_n.shape, c_n.shape)
        # Apply linear transformation
        x = x.reshape(-1, self.lstm_size)
        x = self.linear_projection_out(x)
        x = x.reshape(-1, self.output_len, self.output_dim)
        #print(x.shape)
        return x

In [292]:
kw = KernelWrapper(LSTM, input_dim=1, input_len=4, 
                 output_dim=128, output_len=1, 
                 num_hidden_layers=1, 
                 mode="add", params={"num_hidden_layers":1, 
                                         "drop_out_p":0.05,
                                         "activation":"tanh"}, 
                                verbose=False)
print(kw)
print(kw.kernel.is_in_encoder)

kernal  <class '__main__.LSTM'> is a Kernel
KernelWrapper(
  (kernel): LSTM(
    (linear_projection_in): Linear(in_features=4, out_features=512, bias=True)
    (linear_projection_out): Linear(in_features=512, out_features=128, bias=True)
    (lstm): LSTM(128, 128, batch_first=True, dropout=0.05)
  )
)
True




In [293]:
x = torch.rand(size=(13,4,1))
print(x.shape)
x = kw(x)
print(x.shape)

torch.Size([13, 4, 1])
torch.Size([13, 1, 128])


# RNN

In [294]:
class RNN(Kernel):
    def __init__(self, input_dim, input_len, 
                 output_dim, output_len, params={}):
        super(RNN, self).__init__(input_dim, input_len, 
                 output_dim, output_len)
        # declear parameters
        self.drop_out_p = 0.05
        self.num_hidden_layers = 0
        self.update_params(params=params)


        self.lstm_dim = max(input_dim, output_dim)
        self.lstm_len = max(input_len, output_len)

        # compute input and output size
        self.in_size = input_len*input_dim
        self.out_size = output_len*output_dim
        self.lstm_size = self.lstm_dim*self.lstm_len

        # prepare layers
        self.layers = []

        # check in encoder or decoder 
        self.is_in_encoder = (self.input_len >= self.output_len)
        self.is_in_decoder = not self.is_in_encoder

        # Define the LSTM and Linear layers
        self.linear_projection_in = nn.Linear(self.in_size, self.lstm_size)
        self.linear_projection_out = nn.Linear(self.lstm_size, self.out_size)

        self.lstm = nn.RNN(self.lstm_dim, self.lstm_dim, 
                            self.num_hidden_layers, dropout=self.drop_out_p, 
                            batch_first=True)

    def forward(self, x):
        """
        Forward pass for LSTM. If we are in encoder mode, process the input sequence through LSTM
        and use the last hidden state to compute the final output using the linear layer.
        """
        #print(x.shape)
        x = x.reshape(-1, self.in_size)
        #print(x.shape)
        x = self.linear_projection_in(x)
        #print(x.shape)
        x = x.reshape(-1, self.lstm_len, self.lstm_dim)
        #print(x.shape)

        # Pass through LSTM
        x, _ = self.lstm(x)  # x, lstm_out contains all hidden states, h_n is the last hidden state

        # Use the last hidden state (h_n) for linear transformation
        #print(x.shape, h_n.shape, c_n.shape)
        # Apply linear transformation
        x = x.reshape(-1, self.lstm_size)
        x = self.linear_projection_out(x)
        x = x.reshape(-1, self.output_len, self.output_dim)
        #print(x.shape)
        return x

In [295]:
kw = KernelWrapper(RNN, input_dim=1, input_len=4, 
                 output_dim=128, output_len=1, 
                 num_hidden_layers=1, 
                 mode="add", params={"num_hidden_layers":1, 
                                         "drop_out_p":0.05,
                                         "activation":"tanh"}, 
                                verbose=False)
print(kw)
print(kw.kernel.is_in_encoder)

kernal  <class '__main__.RNN'> is a Kernel
KernelWrapper(
  (kernel): RNN(
    (linear_projection_in): Linear(in_features=4, out_features=512, bias=True)
    (linear_projection_out): Linear(in_features=512, out_features=128, bias=True)
    (lstm): RNN(128, 128, batch_first=True, dropout=0.05)
  )
)
True


In [296]:
x = torch.rand(size=(13,4,1))
print(x.shape)
x = kw(x)
print(x.shape)

torch.Size([13, 4, 1])
torch.Size([13, 1, 128])


# Transformer Kernel

In [297]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        assert d_model % self.num_heads == 0, "d_model must be divisible by num_heads"
        self.depth = d_model // self.num_heads
        
        # Linear layers for queries, keys, and values
        self.Wq = nn.Linear(d_model, d_model)
        self.Wk = nn.Linear(d_model, d_model)
        self.Wv = nn.Linear(d_model, d_model)
        
        # Output linear layer
        self.fc = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        batch_size = x.size(0)

        # Linear projections for Q, K, V
        Q = self.Wq(x).view(batch_size, -1, self.num_heads, self.depth)
        K = self.Wk(x).view(batch_size, -1, self.num_heads, self.depth)
        V = self.Wv(x).view(batch_size, -1, self.num_heads, self.depth)

        # Permute to bring num_heads dimension to second position
        Q = Q.permute(0, 2, 1, 3)
        K = K.permute(0, 2, 1, 3)
        V = V.permute(0, 2, 1, 3)

        # Scaled dot-product attention
        scores = torch.matmul(Q, K.permute(0, 1, 3, 2)) / math.sqrt(self.depth)
        
        if mask is not None:
            mask = mask.unsqueeze(1).unsqueeze(2)  # (batch_size, 1, 1, seq_len)
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        attention = F.softmax(scores, dim=-1)

        # Weighted sum of value vectors
        out = torch.matmul(attention, V)
        out = out.permute(0, 2, 1, 3).contiguous().view(batch_size, -1, self.d_model)

        # Final linear transformation
        out = self.fc(out)
        return out

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=500):
        super(PositionalEncoding, self).__init__()
        
        if d_model % 2 == 1:
            d_model_1 = d_model + 1
        else:
            d_model_1 = d_model

        pe = torch.zeros(max_len, d_model_1)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model_1, 2).float() * (-math.log(10000.0) / d_model_1))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe_const', pe[:, :d_model])

    def forward(self, x):
        x = x + self.pe_const[:x.size(1), :].unsqueeze(0)
        return x

class AttentionBlock(nn.Module):
    def __init__(self, d_model, num_heads):
        super(AttentionBlock, self).__init__()
        self.multi_head_attention = MultiHeadAttention(d_model, num_heads)
        self.relu = nn.LeakyReLU()
        self.linear = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        # First residual connection
        residual = x
        x = self.multi_head_attention(x, mask)
        x = self.relu(x) + residual

        # Second residual connection
        residual = x
        x = self.linear(x)
        x = x + residual

        return x


In [298]:
class Transformer(Kernel):
    def __init__(self, input_dim, input_len, 
                 output_dim, output_len, params={}):
        super(Transformer, self).__init__(input_dim, input_len, 
                 output_dim, output_len)
        # declear parameters
        self.drop_out_p = 0.05
        self.num_hidden_layers = 0
        self.num_heads = 2
        self.update_params(params=params)


        self.lstm_dim = max(input_dim, output_dim)
        self.lstm_len = max(input_len, output_len)

        # compute input and output size
        self.in_size = input_len*input_dim
        self.out_size = output_len*output_dim
        self.lstm_size = self.lstm_dim*self.lstm_len

        # prepare layers
        self.layers = []

        # check in encoder or decoder 
        self.is_in_encoder = (self.input_len >= self.output_len)
        self.is_in_decoder = not self.is_in_encoder

        # Define the LSTM and Linear layers
        self.linear_projection_in = nn.Linear(self.in_size, self.lstm_size)
        self.linear_projection_out = nn.Linear(self.lstm_size, self.out_size)

        self.attention = nn.Sequential(*[
                                AttentionBlock(d_model=self.lstm_dim, 
                                num_heads=self.num_heads) for i in range(self.num_hidden_layers)])

    def forward(self, x):
        """
        Forward pass for LSTM. If we are in encoder mode, process the input sequence through LSTM
        and use the last hidden state to compute the final output using the linear layer.
        """
        #print(x.shape)
        x = x.reshape(-1, self.in_size)
        #print(x.shape)
        x = self.linear_projection_in(x)
        #print(x.shape)
        x = x.reshape(-1, self.lstm_len, self.lstm_dim)
        #print(x.shape)

        # Pass through LSTM
        x = self.attention(x)  # x, lstm_out contains all hidden states, h_n is the last hidden state

        # Use the last hidden state (h_n) for linear transformation
        #print(x.shape, h_n.shape, c_n.shape)
        # Apply linear transformation
        x = x.reshape(-1, self.lstm_size)
        x = self.linear_projection_out(x)
        x = x.reshape(-1, self.output_len, self.output_dim)
        #print(x.shape)
        return x

In [299]:
kw = KernelWrapper(Transformer, input_dim=128, input_len=1, 
                 output_dim=128, output_len=4, 
                 num_hidden_layers=1, 
                 mode="add", params={"num_hidden_layers":1, 
                                         "drop_out_p":0.05,
                                         "activation":"tanh"}, 
                                verbose=False)
print(kw)
print(kw.kernel.is_in_encoder)

kernal  <class '__main__.Transformer'> is a Kernel
KernelWrapper(
  (kernel): Transformer(
    (linear_projection_in): Linear(in_features=128, out_features=512, bias=True)
    (linear_projection_out): Linear(in_features=512, out_features=512, bias=True)
    (attention): Sequential(
      (0): AttentionBlock(
        (multi_head_attention): MultiHeadAttention(
          (Wq): Linear(in_features=128, out_features=128, bias=True)
          (Wk): Linear(in_features=128, out_features=128, bias=True)
          (Wv): Linear(in_features=128, out_features=128, bias=True)
          (fc): Linear(in_features=128, out_features=128, bias=True)
        )
        (relu): LeakyReLU(negative_slope=0.01)
        (linear): Linear(in_features=128, out_features=128, bias=True)
      )
    )
  )
)
False


In [300]:
x = torch.rand(size=(13,1,128))
print(x.shape)
x = kw(x)
print(x.shape)

torch.Size([13, 1, 128])
torch.Size([13, 4, 128])


# Model


In [301]:
from kunlib_v2 import KUNet

In [None]:
KUNet(input_dim=128, input_len, 
                 n_width=1, n_height=1, 
                 output_dim=1, output_len=1, 
                 hidden_dim=20,  
                 kernal_model=nn.Linear, 
                 params={} verbose=False):

In [None]:
kun = KUNet(input_dim=1, input_len=8, 
                 n_width=[1], n_height=[8,8], 
                 latent_dim=128, latent_len=1, 
                 output_dim=1, output_len=8, 
                 hidden_dim=128, num_hidden_layers=0, 
                 kernel=[Linear, LSTM, Transformer], non_linear_kernel_pos='011',
                 skip_conn=True, skip_mode="concat",
                 inverse_norm=False, mean_norm=True,
                 chanel_independent=True, residual = True, verbose=False)

print(kun)