# Key

> #### [#] => There is a written explanation
> #### (#) => There is a code explanation

# Embedding

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

import math

class DSW_embedding(nn.Module):
    """Dimension-segment-wise (DSW) embedding.

    Every variable of a multivariate series is cut into consecutive segments
    of ``seg_len`` time steps, and each segment is projected to a ``d_model``
    vector by one linear layer shared across variables and segments.

    Args:
        seg_len: number of time steps per segment. The input length has to be
            a multiple of it for the segment split in ``forward`` to succeed.
        d_model: size of the vector produced for each segment.
    """
    def __init__(self, seg_len, d_model):
        super(DSW_embedding, self).__init__()
        self.seg_len = seg_len

        self.linear = nn.Linear(seg_len, d_model)

    def forward(self, x):
        """Embed a batch of series segment by segment.

        Args:
            x: tensor of shape [batch, ts_len, ts_dim].

        Returns:
            Tensor of shape [batch, ts_dim, seg_num, d_model] where
            ``seg_num = ts_len // seg_len``.
        """
        batch, _, ts_dim = x.shape
        # The prints trace the shape after every step; they are kept on purpose
        # because this notebook cell is a shape walkthrough.
        print(x.shape)
        # Put every (sample, variable, segment) triple on its own row so a
        # single Linear(seg_len -> d_model) embeds all segments at once.
        x_segment = rearrange(x, 'b (seg_num seg_len) d -> (b d seg_num) seg_len', seg_len = self.seg_len)
        print(x_segment.shape)
        x_embed = self.linear(x_segment)
        print(x_embed.shape)
        # Unfold the rows back into separate batch / variable / segment axes.
        x_embed = rearrange(x_embed, '(b d seg_num) d_model -> b d seg_num d_model', b = batch, d = ts_dim)
        print(x_embed.shape)

        # Previously commented out, which made the module return None.
        return x_embed
# Demo: segments of 20 time steps, each embedded into 512 features.
embed = DSW_embedding(20,512)
# Random batch of 16 series with 100 time steps and 20 variables,
# i.e. 100 / 20 = 5 segments per variable.
data = torch.randn(16,100,20)
# forward() prints the tensor shape after each step of the embedding.
embed(data)

# Attention

## Attention Layer
> ### There is nothing special about these attention layers; they are the standard multi-head attention implementation. All of the novelty of the Crossformer comes from how they are used (the embedding and the reshaping around them), so these layers can be taken as-is.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
import numpy as np

from math import sqrt

class FullAttention(nn.Module):
    '''
    Scaled dot-product attention on already-split heads.

    scale: multiplier applied to the raw scores before the softmax; when it is
           unset (or falsy) 1/sqrt(E) is used, E being the query feature size
    attention_dropout: dropout probability applied to the attention weights
    '''
    def __init__(self, scale=None, attention_dropout=0.1):
        super(FullAttention, self).__init__()
        self.scale = scale
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries, keys, values):
        '''
        queries: [B, L, H, E], keys: [B, S, H, E], values: [B, S, H, D]
        returns: [B, L, H, D], contiguous in memory
        '''
        _, _, _, head_dim = queries.shape
        temperature = self.scale if self.scale else 1. / sqrt(head_dim)

        # One score per (batch, head, query position, key position).
        logits = torch.einsum("blhe,bshe->bhls", queries, keys)
        weights = torch.softmax(temperature * logits, dim=-1)
        weights = self.dropout(weights)

        # Weighted sum of the values, laid out position-major again.
        attended = torch.einsum("bhls,bshd->blhd", weights, values)
        return attended.contiguous()

class AttentionLayer(nn.Module):
    '''
    The Multi-head Self-Attention (MSA) Layer.

    Projects the inputs to queries/keys/values, splits them into n_heads
    heads, runs FullAttention and projects the result back to d_model.

    d_keys / d_values: per-head widths, defaulting to d_model // n_heads
    mix: when True the head and position axes are swapped before the heads
         are flattened (see forward)
    '''
    def __init__(self, d_model, n_heads, d_keys=None, d_values=None, mix=True, dropout = 0.1):
        super(AttentionLayer, self).__init__()

        per_head = d_model // n_heads
        d_keys = d_keys or per_head
        d_values = d_values or per_head
        qk_width = d_keys * n_heads
        v_width = d_values * n_heads

        self.inner_attention = FullAttention(scale=None, attention_dropout = dropout)
        self.query_projection = nn.Linear(d_model, qk_width)
        self.key_projection = nn.Linear(d_model, qk_width)
        self.value_projection = nn.Linear(d_model, v_width)
        self.out_projection = nn.Linear(v_width, d_model)
        self.n_heads = n_heads
        self.mix = mix

    def forward(self, queries, keys, values):
        '''
        queries: [B, L, d_model]; keys and values: [B, S, d_model]
        returns: [B, L, d_model]
        '''
        n_batch, q_len, _ = queries.shape
        _, kv_len, _ = keys.shape
        heads = self.n_heads

        # Split the projected features into heads: [B, len, H, per-head width].
        q = self.query_projection(queries).view(n_batch, q_len, heads, -1)
        k = self.key_projection(keys).view(n_batch, kv_len, heads, -1)
        v = self.value_projection(values).view(n_batch, kv_len, heads, -1)

        attended = self.inner_attention(q, k, v)
        if self.mix:
            # [B, L, H, D] -> [B, H, L, D]; the view below then regroups the
            # head-major memory into L rows instead of concatenating each
            # position's own heads.
            attended = attended.transpose(2, 1).contiguous()
        flat = attended.view(n_batch, q_len, -1)

        return self.out_projection(flat)


## Two Stage Attention



> ### This two-stage attention is an explicit, spelled-out version written for readability. You wouldn't want to use this version in production.
>> #### I am going to change all of this for the production code, but for the sake of understanding, I'll label every step.

In [9]:
class TwoStageAttentionLayer(nn.Module):
    '''
    The Two Stage Attention (TSA) Layer.

    Stage 1 (cross-time) attends over the segments of each variable on its own.
    Stage 2 (cross-dimension) lets the variables exchange information through
    `factor` learned router vectors per segment.
    input/output shape: [batch_size, Data_dim(D), Seg_num(L), d_model]
    '''
    def __init__(self, seg_num, factor, d_model, n_heads, d_ff = None, dropout=0.1):
        super(TwoStageAttentionLayer, self).__init__()
        hidden = d_ff or 4*d_model
        self.time_attention = AttentionLayer(d_model, n_heads, dropout = dropout)
        self.dim_sender = AttentionLayer(d_model, n_heads, dropout = dropout)
        self.dim_receiver = AttentionLayer(d_model, n_heads, dropout = dropout)
        # One set of `factor` router vectors for every segment position.
        self.router = nn.Parameter(torch.randn(seg_num, factor, d_model))

        self.dropout = nn.Dropout(dropout)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.norm4 = nn.LayerNorm(d_model)

        self.MLP1 = self._make_mlp(d_model, hidden)
        self.MLP2 = self._make_mlp(d_model, hidden)

    @staticmethod
    def _make_mlp(d_model, hidden):
        # Position-wise feed-forward net: d_model -> hidden -> d_model.
        return nn.Sequential(nn.Linear(d_model, hidden),
                             nn.GELU(),
                             nn.Linear(hidden, d_model))

    def _cross_time(self, x):
        # Fold the variables into the batch axis so that seg_num sits in
        # dimension 1, which is the axis the attention layer attends over.
        tokens = rearrange(x, 'b ts_d seg_num d_model -> (b ts_d) seg_num d_model')

        # [MHA] followed by [Add and Norm]
        attended = self.time_attention(tokens, tokens, tokens)
        tokens = self.norm1(tokens + self.dropout(attended))

        # [Feed Forward Add and Norm]
        return self.norm2(tokens + self.dropout(self.MLP1(tokens)))

    def _cross_dimension(self, tokens, batch):
        # Now put ts_d in dimension 1 so attention runs across the variables.
        per_segment = rearrange(tokens, '(b ts_d) seg_num d_model -> (b seg_num) ts_d d_model', b = batch)

        # The router is a proxy for the queries. Instead of every variable
        # attending to every other variable directly, the `factor` router
        # vectors first collect from all variables (dim_sender) and the
        # variables then read back from that buffer (dim_receiver).
        routers = repeat(self.router, 'seg_num factor d_model -> (repeat seg_num) factor d_model', repeat = batch)
        gathered = self.dim_sender(routers, per_segment, per_segment)
        distributed = self.dim_receiver(per_segment, gathered, gathered)

        # [Add and Norm]
        per_segment = self.norm3(per_segment + self.dropout(distributed))

        # [Feed Forward Add and Norm]
        return self.norm4(per_segment + self.dropout(self.MLP2(per_segment)))

    def forward(self, x):
        batch = x.shape[0]

        #Cross Time Stage: Directly apply MSA to each dimension
        time_out = self._cross_time(x)

        #Cross Dimension Stage: use a small set of learnable vectors to aggregate and distribute messages to build the D-to-D connection
        dim_out = self._cross_dimension(time_out, batch)

        # Restore the [batch, ts_d, seg_num, d_model] layout of the input.
        return rearrange(dim_out, '(b seg_num) ts_d d_model -> b ts_d seg_num d_model', b = batch)

### An example of how the router works to diminish the size and complexity of cross dimension encoding

In [6]:
# 10 router vectors stand in for the queries of 120 tokens.
queries_router = torch.randn(16, 10, 512)
# Keys are laid out as [16, 512, 120] so they can be multiplied directly.
keys = torch.randn(16, 120, 512).permute(0, 2, 1)
values = torch.randn(16, 120, 512)

# The score matrix is only 10 x 120 per sample.
attention = queries_router @ keys
attention.shape

torch.Size([16, 10, 120])

# Encoder

### General

#### The Encoder consists of three blocks: the segment merging, the scale block, and the encoder block. As opposed to a normal transformer, where each encoder block remains hidden, Crossformer encoder blocks encode information at different scales, and thus the output of each hidden layer needs to be preserved. In the Scale Block, there is a parameter that decides how many TSA blocks you want. This means that each encoder block could consist of 1 merging layer followed by 5 TSA blocks. In the encoder module, you then use an iterator to go through each block and perform attention at varying granularities of data.

### Segment Merging

> #### [1]: This just makes sure the number of segments is a multiple of `win_size`, padding if necessary, so that the segments can be merged in groups. 

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
#from modules.layers import FullAttention, AttentionLayer, TwoStageAttentionLayer
from math import ceil

class SegMerging(nn.Module):
    '''
    Segment Merging Layer.
    The adjacent `win_size' segments in each dimension will be merged into one segment to
    get representation of a coarser scale
    we set win_size = 2 in our paper

    Parameters:
    d_model: feature size of every segment
    win_size: number of adjacent segments merged into one
    norm_layer: normalisation class applied to the concatenated segments
    '''
    def __init__(self, d_model, win_size, norm_layer=nn.LayerNorm):
        super().__init__()
        self.d_model = d_model
        self.win_size = win_size
        self.linear_trans = nn.Linear(win_size * d_model, d_model)
        self.norm = norm_layer(win_size * d_model)

    def forward(self, x):
        """
        x: B, ts_d, L, d_model
        returns: B, ts_d, ceil(L / win_size), d_model
        """
        # [1] Pad until the number of segments is a multiple of win_size.
        _, _, seg_num, _ = x.shape
        pad_num = -seg_num % self.win_size
        if pad_num:
            # (1) Duplicate the trailing `pad_num` segments. The indices are
            # clamped at 0 so that an input with fewer than `pad_num` segments
            # still receives exactly `pad_num` extra ones; a plain
            # x[:, :, -pad_num:, :] would return too few and break the merge.
            pad_idx = torch.arange(seg_num - pad_num, seg_num, device=x.device).clamp(min=0)
            x = torch.cat((x, x.index_select(2, pad_idx)), dim = -2)

        # (2) Lay the `win_size` neighbouring segments side by side on the
        # feature axis: [B, ts_d, seg_num/win_size, win_size*d_model]
        merged = torch.cat(
            [x[:, :, offset::self.win_size, :] for offset in range(self.win_size)],
            dim = -1,
        )

        return self.linear_trans(self.norm(merged))

#### Code Snippets

##### (1) => The `[-1:]` slice on dimension 2 (the segment axis) just says: take the last segment and concatenate it onto the end. This way, our padding just provides a couple of duplicate values

In [None]:
# Toy tensor shaped like [batch, ts_d, seg_num, d_model].
x = torch.arange(16).view(2, 2, 2, 2)
print(x.shape)
print(x)
# Append a copy of the last segment along the segment axis.
x = torch.cat([x, x[:, :, -1:]], dim=2)
print(x.shape)
print(x)

##### (2) => This does the exact same thing as x.reshape(1,2,5,4)

In [8]:
import torch
win_size = 2
x = torch.arange(40).reshape(1, 2, 10, 2)
o, t, th, f = x.shape
# Strided views: segments 0, 2, 4, ... and segments 1, 3, 5, ...
seg_to_merge = [x[:, :, start::win_size, :] for start in range(win_size)]
# Side by side on the last axis, each row holds one pair of neighbours.
x = torch.cat(seg_to_merge, -1)
print(x)
# Already in the merged layout, so this reshape changes nothing.
x = x.reshape(o, t, th // win_size, f * win_size)
print(x)

tensor([[[[ 0,  1,  2,  3],
          [ 4,  5,  6,  7],
          [ 8,  9, 10, 11],
          [12, 13, 14, 15],
          [16, 17, 18, 19]],

         [[20, 21, 22, 23],
          [24, 25, 26, 27],
          [28, 29, 30, 31],
          [32, 33, 34, 35],
          [36, 37, 38, 39]]]])
tensor([[[[ 0,  1,  2,  3],
          [ 4,  5,  6,  7],
          [ 8,  9, 10, 11],
          [12, 13, 14, 15],
          [16, 17, 18, 19]],

         [[20, 21, 22, 23],
          [24, 25, 26, 27],
          [28, 29, 30, 31],
          [32, 33, 34, 35],
          [36, 37, 38, 39]]]])


In [38]:
# A plain Python list of numbers converts directly into a 1-D tensor.
mylist = [1,2,3,4,5]
torch.tensor(mylist)

tensor([1, 2, 3, 4, 5])

### Scale Block 

> #### [General]: Notice the overall structure. First, there is a segment layer, then there is 'depth' number of TSA layers
> #### [1]: this is there to run through the TSA layer before any merging occurs on the initial input.| Encoder module [2]
> #### [2]: Start out as an empty module list, but fills up with however many TSA layers we want
> #### [3]: This line does the filling of the encode layers.

In [None]:
class scale_block(nn.Module):
    '''
    One scale of the encoder: an optional segment merging layer followed by
    `depth' TSA layers.
    We set depth = 1 in the paper

    Parameters:
    win_size: How many adjacent segments are merged. With 2, every block is twice as coarse as the previous one; with 1 (or less) no merging layer is created
    d_model: # of features
    n_heads: # of heads
    d_ff: Feed forward dimension
    depth: # of TSA layers in between each merge
    dropout: dropout probability handed to every TSA layer
    seg_num: number of segments the TSA layers of this block operate on
    factor: number of router vectors per segment in the TSA layers
    '''
    def __init__(self, win_size, d_model, n_heads, d_ff, depth, dropout,
                    seg_num = 10, factor=10):
        super(scale_block, self).__init__()
        # [1] Only merge when there is something to merge.
        self.merge_layer = SegMerging(d_model, win_size, nn.LayerNorm) if win_size > 1 else None

        # [2] + [3] The stack of `depth` TSA layers.
        self.encode_layers = nn.ModuleList([
            TwoStageAttentionLayer(seg_num, factor, d_model, n_heads, d_ff, dropout)
            for _ in range(depth)
        ])

    def forward(self, x):
        # Unpacking four sizes rejects inputs that are not 4-D.
        _, _, _, _ = x.shape

        if self.merge_layer is not None:
            x = self.merge_layer(x)

        for tsa_layer in self.encode_layers:
            x = tsa_layer(x)

        return x



### Encoder Module


> #### [1]: This starts out as an empty modulelist, but fills up with however many scale_blocks we want
> #### [2]: Notice that the win_size is 1. I believe this is here so that we don't start out segmenting, as is written in the paper. After we process the initial input, we can then go through the TSA, and segment to get coarser representations as we go along. | Scale Block [1]
> #### [3]: From here on out, I think it's just business as usual. Stack up the encoder blocks like you would in any other transformer based model

In [None]:
class Encoder(nn.Module):
    '''
    The Encoder of Crossformer: a stack of scale blocks whose outputs are all
    kept, one per scale.

    Parameters:
    e_blocks: total number of scale blocks (the first one never merges)
    win_size: merge window of every block after the first
    block_depth: # of TSA layers inside each scale block
    in_seg_num: number of segments of the encoder input
    '''
    def __init__(self, e_blocks, win_size, d_model, n_heads, d_ff, block_depth, dropout,
                in_seg_num = 10, factor=10):
        super(Encoder, self).__init__()
        # [2] The first block uses win_size 1, i.e. no merging on the raw input.
        blocks = [scale_block(1, d_model, n_heads, d_ff, block_depth, dropout,
                              in_seg_num, factor)]
        # [3] Every later block merges again, so block `level` works on
        # ceil(in_seg_num / win_size**level) segments.
        blocks.extend(
            scale_block(win_size, d_model, n_heads, d_ff, block_depth, dropout,
                        ceil(in_seg_num/win_size**level), factor)
            for level in range(1, e_blocks)
        )
        # [1]
        self.encode_blocks = nn.ModuleList(blocks)

    def forward(self, x):
        '''
        Returns a list holding the input followed by the output of each block.
        '''
        per_scale = [x]

        for block in self.encode_blocks:
            x = block(x)
            per_scale.append(x)

        return per_scale

# Decoder

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

class DecoderLayer(nn.Module):
    '''
    The decoder layer of Crossformer, each layer will make a prediction at its scale

    seg_len: number of time steps predicted per segment
    d_ff: feed-forward width handed to the TSA layer (MLP1 below stays at d_model)
    out_seg_num: number of segments on the decoder side
    factor: number of router vectors per segment in the TSA layer
    '''
    def __init__(self, seg_len, d_model, n_heads, d_ff=None, dropout=0.1, out_seg_num = 10, factor = 10):
        super(DecoderLayer, self).__init__()
        self.self_attention = TwoStageAttentionLayer(out_seg_num, factor, d_model, n_heads,
                                                     d_ff, dropout)
        self.cross_attention = AttentionLayer(d_model, n_heads, dropout = dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.MLP1 = nn.Sequential(nn.Linear(d_model, d_model),
                                nn.GELU(),
                                nn.Linear(d_model, d_model))
        self.linear_pred = nn.Linear(d_model, seg_len)

    def forward(self, x, cross):
        '''
        x: the output of last decoder layer
        cross: the output of the corresponding encoder layer

        returns (dec_output, layer_predict):
          dec_output: [b, ts_d, out_seg_num, d_model], input for the next layer
          layer_predict: [b, ts_d * out_seg_num, seg_len], this layer's forecast
        '''
        batch = x.shape[0]

        # [Self MHA] No mask is applied here. In Crossformer.forward the decoder
        # input is a learned position embedding rather than shifted targets.
        x = self.self_attention(x)

        # Fold the variables into the batch axis on both sides so that the
        # cross attention runs over segments.
        queries = rearrange(x, 'b ts_d out_seg_num d_model -> (b ts_d) out_seg_num d_model')
        memory = rearrange(cross, 'b ts_d in_seg_num d_model -> (b ts_d) in_seg_num d_model')

        # [MHA] + [Add and Norm]
        attended = self.cross_attention(queries, memory, memory)
        hidden = self.norm1(queries + self.dropout(attended))

        # [Feed Forward Add and Norm]
        dec_output = self.norm2(hidden + self.MLP1(hidden))
        dec_output = rearrange(dec_output, '(b ts_d) seg_dec_num d_model -> b ts_d seg_dec_num d_model', b = batch)

        # Every segment embedding is mapped to seg_len predicted values.
        layer_predict = rearrange(self.linear_pred(dec_output),
                                  'b out_d seg_num seg_len -> b (out_d seg_num) seg_len')

        return dec_output, layer_predict

In [None]:
class Decoder(nn.Module):
    '''
    The decoder of Crossformer, making the final prediction by adding up predictions at each scale

    d_layers: number of decoder layers; layer i reads cross[i] in forward
    router: stored on the instance, not used inside this class
    '''
    def __init__(self, seg_len, d_layers, d_model, n_heads, d_ff, dropout,
                router=False, out_seg_num = 10, factor=10):
        super(Decoder, self).__init__()

        self.router = router
        self.decode_layers = nn.ModuleList([
            DecoderLayer(seg_len, d_model, n_heads, d_ff, dropout, out_seg_num, factor)
            for _ in range(d_layers)
        ])

    def forward(self, x, cross):
        '''
        x: decoder input, [b, ts_d, out_seg_num, d_model]
        cross: sequence of encoder outputs, indexed by decoder layer
        returns: [b, out_seg_num * seg_len, ts_d]
        '''
        ts_d = x.shape[1]
        final_predict = None

        for idx, layer in enumerate(self.decode_layers):
            # Each decoder layer is paired with one encoder output.
            x, layer_predict = layer(x, cross[idx])
            final_predict = layer_predict if final_predict is None else final_predict + layer_predict

        # Split the fused (variable, segment) axis and stitch the segments of
        # each variable back into one time axis.
        return rearrange(final_predict, 'b (out_d seg_num) seg_len -> b (seg_num seg_len) out_d', out_d = ts_d)


# Crossformer

> #### [1]: (1, data dimensions, seg_num, d_model) If you notice, these are the input dimensions that match the encoder input.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

from modules.encoder import Encoder
from modules.decoder import Decoder
from modules.layers import FullAttention, AttentionLayer, TwoStageAttentionLayer
from modules.embed import DSW_embedding

from math import ceil

class Crossformer(nn.Module):
    '''
    The full model: DSW embedding -> multi-scale encoder -> decoder that sums
    the predictions made at every scale.

    data_dim: number of variables in the series
    in_len / out_len: input and forecast lengths in time steps
    seg_len: time steps per segment
    win_size: merge window of the encoder
    baseline: when True the per-variable mean of the input is added to the forecast
    device: stored on the instance, not used inside this class
    '''
    def __init__(self, data_dim, in_len, out_len, seg_len, win_size = 2,
                factor=10, d_model=512, d_ff = 1024, n_heads=8, e_layers=3, 
                dropout=0.0, baseline = False, device=torch.device('cuda:0')):
        super(Crossformer, self).__init__()
        self.data_dim = data_dim
        self.in_len = in_len
        self.out_len = out_len
        self.seg_len = seg_len
        self.merge_win = win_size

        self.baseline = baseline

        self.device = device

        # The padding operation to handle indivisible segment length:
        # both lengths are rounded up to the next multiple of seg_len.
        self.pad_in_len = ceil(1.0 * in_len / seg_len) * seg_len
        self.pad_out_len = ceil(1.0 * out_len / seg_len) * seg_len
        self.in_len_add = self.pad_in_len - self.in_len

        in_seg_num = self.pad_in_len // seg_len
        out_seg_num = self.pad_out_len // seg_len

        # Embedding
        self.enc_value_embedding = DSW_embedding(seg_len, d_model)
        self.enc_pos_embedding = nn.Parameter(torch.randn(1, data_dim, in_seg_num, d_model))
        self.pre_norm = nn.LayerNorm(d_model)

        # Encoder
        self.encoder = Encoder(e_layers, win_size, d_model, n_heads, d_ff, block_depth = 1,
                               dropout = dropout, in_seg_num = in_seg_num, factor = factor)

        # Decoder
        # [1] The learned position embedding doubles as the decoder input.
        self.dec_pos_embedding = nn.Parameter(torch.randn(1, data_dim, out_seg_num, d_model))
        # e_layers + 1 layers: the encoder returns its input plus one output per block.
        self.decoder = Decoder(seg_len, e_layers + 1, d_model, n_heads, d_ff, dropout,
                               out_seg_num = out_seg_num, factor = factor)

    def forward(self, x_seq):
        '''
        x_seq: [batch, in_len, data_dim]
        returns: [batch, out_len, data_dim]
        '''
        base = x_seq.mean(dim = 1, keepdim = True) if self.baseline else 0
        batch_size = x_seq.shape[0]

        if self.in_len_add != 0:
            # Pad at the front by repeating the first time step.
            front_pad = x_seq[:, :1, :].expand(-1, self.in_len_add, -1)
            x_seq = torch.cat((front_pad, x_seq), dim = 1)

        x_seq = self.enc_value_embedding(x_seq)
        x_seq += self.enc_pos_embedding
        x_seq = self.pre_norm(x_seq)

        enc_out = self.encoder(x_seq)

        # One copy of the learned decoder embedding per sample in the batch.
        dec_in = repeat(self.dec_pos_embedding, 'b ts_d l d -> (repeat b) ts_d l d', repeat = batch_size)
        predict_y = self.decoder(dec_in, enc_out)

        # Drop the steps that only exist because out_len was padded.
        return base + predict_y[:, :self.out_len, :]