In [1]:
import torch
import torch.nn as nn
from torch import einsum
import torch.nn.functional as F
import numpy as np
from models import DLinear
from models import PatchTST 
from models import informer

class OffsetScale(nn.Module):
    """Learned per-head scale and offset that fans one tensor out into `heads` tensors.

    For an input whose last axis has size `dim`, head h is ``x * gamma[h] + beta[h]``.
    The heads are returned as a tuple of tensors shaped like the input
    (e.g. heads=2 yields a query/key pair).
    """

    def __init__(self, dim, heads = 1):
        super().__init__()
        # Creation order matters under a fixed seed: gamma starts as ones and is then
        # re-drawn from N(0, 0.02); beta stays at zero.
        self.gamma = nn.Parameter(torch.ones(heads, dim))
        self.beta = nn.Parameter(torch.zeros(heads, dim))
        nn.init.normal_(self.gamma, std = 0.02)

    def forward(self, x):
        # (..., d) with (h, d) -> (..., h, d): a broadcast multiply, nothing is summed.
        per_head = torch.einsum('... d, h d -> ... h d', x, self.gamma)
        per_head = per_head + self.beta
        # Split the head axis into a tuple of `heads` tensors.
        return per_head.unbind(dim = -2)
    
class ReLUSquared(nn.Module):
    """Elementwise squared ReLU: ``max(x, 0) ** 2``, used in place of a softmax."""

    def forward(self, x):
        rectified = F.relu(x)
        return rectified.pow(2)

class GAU(nn.Module):
    """Gated Attention Unit: single-head attention whose output is gated elementwise.

    One shared projection produces both queries and keys (split by `OffsetScale`).
    Attention weights are the squared ReLU of the scores divided by seq_len (no
    softmax). The attended values are multiplied by a SiLU gate before the output
    projection.

    Args:
        dim: feature size of the last input axis.
        query_key_dim: width of the shared query/key projection.
        expansion_factor: value/gate width is int(expansion_factor * dim).
        add_residual: if True, add the input `x` to the output.
        causal: if True, position i only attends to positions j <= i.
        dropout: dropout rate for the attention weights and the output projection.
        laplace_attn_fn: accepted for signature compatibility but not implemented;
            squared ReLU is always used.
        rel_pos_bias: accepted for signature compatibility; no bias module is built.
            An additive bias tensor can be passed to `forward` instead.
        norm_klass: normalization layer constructor applied to the input.
    """

    def __init__(
        self,
        *,
        dim,
        query_key_dim = 32, ##todo
        expansion_factor = 2.,
        add_residual = True, ##todo
        causal = False,
        dropout = 0.,
        laplace_attn_fn = False,
        rel_pos_bias = False,
        norm_klass = nn.LayerNorm
    ):
        super().__init__()
        hidden_dim = int(expansion_factor * dim)

        # Previously accepted but ignored; now honoured in forward.
        self.causal = causal

        self.norm = norm_klass(dim)
        self.dropout = nn.Dropout(dropout)

        self.attn_fn = ReLUSquared()

        # One projection yields both the values and the gate (split in forward).
        self.to_hidden = nn.Sequential(
            nn.Linear(dim, hidden_dim * 2),
            nn.SiLU()
        )

        # Shared projection; OffsetScale turns it into separate q and k.
        self.to_qk = nn.Sequential(
            nn.Linear(dim, query_key_dim),
            nn.SiLU()
        )

        self.offsetscale = OffsetScale(query_key_dim, heads = 2)

        self.to_out = nn.Sequential(
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

        self.add_residual = add_residual


    def forward(
        self,
        x,
        rel_pos_bias = None,
        mask = None
    ):
        """Apply the unit.

        Args:
            x: tensor of shape (batch, seq_len, dim); the einsums below require rank 3.
            rel_pos_bias: optional additive bias on the raw scores, broadcastable to
                (batch, seq_len, seq_len).
            mask: optional key mask, True = attend. A (batch, seq_len) mask is applied
                to every query position; otherwise it must broadcast to
                (batch, seq_len, seq_len).

        Returns:
            Tensor with the same shape as `x`.
        """
        seq_len, device = x.shape[-2], x.device

        normed_x = self.norm(x)
        v, gate = self.to_hidden(normed_x).chunk(2, dim = -1)

        qk = self.to_qk(normed_x)
        q, k = self.offsetscale(qk)

        sim = einsum('b i d, b j d -> b i j', q, k)

        if rel_pos_bias is not None:
            sim = sim + rel_pos_bias

        attn = self.attn_fn(sim / seq_len)
        attn = self.dropout(attn)

        # With no softmax to renormalize, masking simply zeroes the weights.
        if mask is not None:
            mask = mask.bool()
            if mask.dim() == 2:
                mask = mask.unsqueeze(1)  # (b, j) -> (b, 1, j): same key mask for every query
            attn = attn.masked_fill(~mask, 0.)

        if self.causal:
            future = torch.ones((seq_len, seq_len), dtype = torch.bool, device = device).triu(1)
            attn = attn.masked_fill(future, 0.)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = out * gate

        out = self.to_out(out)

        if self.add_residual:
            out = out + x

        return out
# Scratch check: build a GAU over 128-dim features (`f` is not referenced by any later visible cell).
f = GAU(dim=128)

# Random input laid out as (24, 128, 42); the Shrinkage check at the end of this cell
# treats it as (batch, channel=128, length=42).
x = torch.rand(24, 128,42)
class BasicBlock(nn.Module):
    """Residual 1-D convolution block whose residual branch ends in a Shrinkage stage.

    Residual branch: conv3 -> BN -> ReLU -> conv3 -> BN -> Shrinkage. The shortcut is the
    identity unless the stride or channel count changes. In that case a 1x1 conv + BN
    projects the input to the residual branch's shape. Output = ReLU(residual + shortcut).

    Args:
        in_channels: channels of the (batch, channels, length) input.
        out_channels: base output channels; the block emits out_channels * expansion.
        stride: stride of the first conv and of the projection shortcut.
    """
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        out_total = out_channels * BasicBlock.expansion
        # Shrinkage sees the output of the final BatchNorm, so it must have
        # out_channels * expansion channels. Sizing it with out_channels only works
        # while expansion == 1. It is registered both as self.shrinkage and inside
        # residual_function; both are kept for state_dict compatibility.
        self.shrinkage = Shrinkage(out_total, gap_size=(1))
        # residual function
        self.residual_function = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv1d(out_channels, out_total, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm1d(out_total),
            self.shrinkage
        )

        # Identity shortcut by default; use a 1x1 conv projection when the output
        # shape differs from the input shape.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_total:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_channels, out_total, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(out_total)
            )

    def forward(self, x):
        # F.relu(..., inplace=True) is what nn.ReLU(inplace=True) calls, without
        # constructing a new module on every forward pass.
        return F.relu(self.residual_function(x) + self.shortcut(x), inplace=True)


class Shrinkage(nn.Module):
    """Soft thresholding with a learned, input-dependent threshold per channel.

    For x of shape (batch, channel, length), the threshold for each (sample, channel) is
    the mean of |x| over the pooled axis times a sigmoid coefficient in (0, 1) produced
    by a small MLP. The output is ``sign(x) * max(|x| - threshold, 0)``. This is the
    channel-wise ("CW") variant.

    Args:
        channel: number of channels; the MLP maps `channel` features to `channel`.
        gap_size: output size of the adaptive average pool. The flatten + Linear(channel)
            in forward only fit when this is 1.
    """
    def __init__(self, channel, gap_size):
        super(Shrinkage, self).__init__()
        self.gap = nn.AdaptiveAvgPool1d(gap_size)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel),
            nn.ReLU(inplace=True),
            nn.Linear(channel, channel),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x_abs = torch.abs(x)
        # Average |x| over the last axis -> (batch, channel).
        average = torch.flatten(self.gap(x_abs), 1)
        # Channel-wise (CW) threshold: one value per (sample, channel), shaped (b, c, 1).
        threshold = torch.mul(average, self.fc(average)).unsqueeze(2)
        # Soft thresholding. clamp(min=0) is max(sub, 0) directly, with no zeros tensor
        # built as sub - sub.
        shrunk = torch.clamp(x_abs - threshold, min=0)
        return torch.mul(torch.sign(x), shrunk)
# Sanity check: soft thresholding returns a tensor of the same shape as its
# (batch, channel, length) input.
shrinkage = Shrinkage(128, gap_size=(1))
x.shape, shrinkage(x).shape

  from .autonotebook import tqdm as notebook_tqdm


(torch.Size([24, 128, 42]), torch.Size([24, 128, 42]))

In [12]:
import torch
import torch.nn as nn
from torch import einsum
import torch.nn.functional as F
import numpy as np
from models import DLinear
from models import PatchTST 
from models import informer

class OffsetScale(nn.Module):
    """Per-head affine map ``x * gamma[h] + beta[h]``, returned as a tuple of `heads` tensors.

    Each returned tensor has the same shape as the input (e.g. heads=2 gives q and k).
    """

    def __init__(self, dim, heads = 1):
        super().__init__()
        # Keep this creation order: under a fixed seed it determines the gamma draw.
        self.gamma = nn.Parameter(torch.ones(heads, dim))
        self.beta = nn.Parameter(torch.zeros(heads, dim))
        nn.init.normal_(self.gamma, std = 0.02)

    def forward(self, x):
        # Broadcast (..., d) against (h, d) to get (..., h, d); no axis is summed.
        stacked = torch.einsum('... d, h d -> ... h d', x, self.gamma)
        stacked = stacked + self.beta
        return stacked.unbind(dim = -2)
    
class ReLUSquared(nn.Module):
    """Squared ReLU activation, ``relu(x) ** 2``, applied elementwise."""

    def forward(self, x):
        return torch.square(F.relu(x))

class Shrinkage(nn.Module):
    """Soft thresholding with a learned, input-dependent threshold per channel.

    For x of shape (batch, channel, length), the threshold for each (sample, channel) is
    the mean of |x| over the pooled axis times a sigmoid coefficient in (0, 1) produced
    by a small MLP. The output is ``sign(x) * max(|x| - threshold, 0)``. This is the
    channel-wise ("CW") variant.

    Args:
        channel: number of channels; the MLP maps `channel` features to `channel`.
        gap_size: output size of the adaptive average pool. The flatten + Linear(channel)
            in forward only fit when this is 1.
    """
    def __init__(self, channel, gap_size):
        super(Shrinkage, self).__init__()
        self.gap = nn.AdaptiveAvgPool1d(gap_size)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel),
            nn.ReLU(inplace=True),
            nn.Linear(channel, channel),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x_abs = torch.abs(x)
        # Average |x| over the last axis -> (batch, channel).
        average = torch.flatten(self.gap(x_abs), 1)
        # Channel-wise (CW) threshold: one value per (sample, channel), shaped (b, c, 1).
        threshold = torch.mul(average, self.fc(average)).unsqueeze(2)
        # Soft thresholding. clamp(min=0) is max(sub, 0) directly, with no zeros tensor
        # built as sub - sub.
        shrunk = torch.clamp(x_abs - threshold, min=0)
        return torch.mul(torch.sign(x), shrunk)

class Refiner_block(nn.Module):
    """Gated-attention block with soft thresholding applied before the residual add.

    Same attention path as a Gated Attention Unit: a shared q/k projection split by
    `OffsetScale`, squared-ReLU weights on scores divided by seq_len, and a SiLU gate
    on the attended values. The projected output is then passed through `Shrinkage`
    across the feature axis before the residual connection.

    Args:
        dim: feature size of the last input axis.
        query_key_dim: width of the shared query/key projection.
        expansion_factor: value/gate width is int(expansion_factor * dim).
        add_residual: if True, add the input `x` to the output.
        causal: if True, position i only attends to positions j <= i.
        dropout: dropout rate for the attention weights and the output projection.
        laplace_attn_fn: accepted for signature compatibility but not implemented;
            squared ReLU is always used.
        rel_pos_bias: accepted for signature compatibility; no bias module is built.
            An additive bias tensor can be passed to `forward` instead.
        norm_klass: normalization layer constructor applied to the input.
    """
    def __init__(
        self,
        dim,
        query_key_dim = 32, ##todo
        expansion_factor = 2.,
        add_residual = True, ##todo
        causal = False,
        dropout = 0.,
        laplace_attn_fn = False,
        rel_pos_bias = False,
        norm_klass = nn.LayerNorm
    ):
        super().__init__()
        hidden_dim = int(expansion_factor * dim)

        # Previously accepted but ignored; now honoured in forward.
        self.causal = causal

        self.norm = norm_klass(dim)
        self.dropout = nn.Dropout(dropout)

        self.attn_fn = ReLUSquared()

        # One projection yields both the values and the gate (split in forward).
        self.to_hidden = nn.Sequential(
            nn.Linear(dim, hidden_dim * 2),
            nn.SiLU()
        )

        # Shared projection; OffsetScale turns it into separate q and k.
        self.to_qk = nn.Sequential(
            nn.Linear(dim, query_key_dim),
            nn.SiLU()
        )

        self.offsetscale = OffsetScale(query_key_dim, heads = 2)

        self.to_out = nn.Sequential(
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

        self.shrinkage = Shrinkage(dim, gap_size=(1))
        self.add_residual = add_residual


    def forward(
        self,
        x,
        rel_pos_bias = None,
        mask = None
    ):
        """Apply the block.

        Args:
            x: tensor of shape (batch, seq_len, dim); the einsums below require rank 3.
            rel_pos_bias: optional additive bias on the raw scores, broadcastable to
                (batch, seq_len, seq_len).
            mask: optional key mask, True = attend. A (batch, seq_len) mask is applied
                to every query position; otherwise it must broadcast to
                (batch, seq_len, seq_len).

        Returns:
            Tensor with the same shape as `x`.
        """
        seq_len, device = x.shape[-2], x.device

        normed_x = self.norm(x)
        v, gate = self.to_hidden(normed_x).chunk(2, dim = -1)

        qk = self.to_qk(normed_x)
        q, k = self.offsetscale(qk)

        sim = einsum('b i d, b j d -> b i j', q, k)

        if rel_pos_bias is not None:
            sim = sim + rel_pos_bias

        attn = self.attn_fn(sim / seq_len)
        attn = self.dropout(attn)

        # With no softmax to renormalize, masking simply zeroes the weights.
        if mask is not None:
            mask = mask.bool()
            if mask.dim() == 2:
                mask = mask.unsqueeze(1)  # (b, j) -> (b, 1, j): same key mask for every query
            attn = attn.masked_fill(~mask, 0.)

        if self.causal:
            future = torch.ones((seq_len, seq_len), dtype = torch.bool, device = device).triu(1)
            attn = attn.masked_fill(future, 0.)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = out * gate

        out = self.to_out(out)

        # Shrinkage expects (batch, channel, length): treat the feature axis as channels
        # and pool over seq_len, then permute back. Soft threshold before the residual.
        out = self.shrinkage(out.permute(0, 2, 1)).permute(0, 2, 1)

        if self.add_residual:
            out = out + x

        return out

class Refiner(nn.Module):
    """Stack of `Refiner_block`s applied to every series of a [Batch, C, P, d] input.

    Args:
        args: configuration object; stored as `self.args` but not read in this class.
        dim: feature size d of the input's last axis (default 128, the previously
            hard-coded value).
        num_blocks: number of stacked blocks (default 4, the previously hard-coded value).
    """
    def __init__(self, args, dim=128, num_blocks=4):
        super(Refiner, self).__init__()
        self.args = args
        self.refiner_block_num = num_blocks
        self.blocks = nn.ModuleList([
            Refiner_block(dim) for _ in range(num_blocks)
        ])

    def forward(self, x):
        """Refine every (batch, channel) series independently.

        Args:
            x: tensor laid out as [Batch, C, P, d]; d must equal the `dim` given at
                construction.

        Returns:
            Tensor of shape [Batch, C, P, d].
        """
        batch = x.shape[0]
        # Fold batch and channel together so each block sees (Batch*C, P, d) sequences.
        x = x.reshape(-1, x.shape[-2], x.shape[-1])

        for block in self.blocks:
            x = block(x)

        return x.reshape(batch, -1, x.shape[-2], x.shape[-1])

In [13]:
# Refiner only stores `args` (nothing reads it), so a placeholder value suffices for a shape check.
r = Refiner(1)

In [14]:
# Dummy input laid out as [Batch, C, P, d]; d must be 128 to match Refiner_block(128).
x = torch.rand(16,1,42,128)

In [15]:
# The Refiner should hand back the input shape unchanged: torch.Size([16, 1, 42, 128]).
r(x).shape

torch.Size([16, 1, 42, 128])

In [19]:
def scaling(x, sigma=1.1, *, loc=2., rng=None):
    """Scaling augmentation: multiply the input by random factors shared across channels.

    # https://arxiv.org/pdf/1706.00527.pdf
    One factor is drawn per (sample, time step) from N(loc, sigma). Every channel of
    that sample is multiplied by the same factor row.

    Args:
        x: array-like (numpy array or CPU tensor) of shape (batch, channels, length).
        sigma: standard deviation of the scaling factors.
        loc: mean of the scaling factors (default 2.0, the previously hard-coded value).
        rng: optional ``numpy.random.Generator``. When None, the legacy global
            ``np.random`` state is used, giving the same draws as before.

    Returns:
        numpy.ndarray of shape (batch, channels, length).
    """
    x = np.asarray(x)
    size = (x.shape[0], x.shape[2])
    if rng is None:
        factor = np.random.normal(loc=loc, scale=sigma, size=size)
    else:
        factor = rng.normal(loc=loc, scale=sigma, size=size)
    # Broadcast the (batch, length) factors over the channel axis. This replaces a
    # per-channel loop + concatenate, which also failed on an empty channel axis.
    return np.multiply(x, factor[:, np.newaxis, :])

In [111]:

class e():
    """Placeholder for a helper whose body is missing from this notebook export.

    NOTE(review): as exported, this class had no body at all, which is a SyntaxError
    that breaks the whole cell. A later cell calls ``e(1)`` and
    ``e.substitude(E, tensor)``, so the original presumably took one constructor
    argument and defined a ``substitude`` method. Restore that implementation before
    relying on the later cell.
    """

In [113]:
# NOTE(review): the class `e` as exported here defines neither a one-argument constructor
# nor a `substitude` method, so this cell cannot run as exported. The output below came
# from an implementation that is missing from this export.
E = e(1)
e.substitude(E, torch.arange(16*96*2).reshape(16, 96, 2))

tensor([[[   0,    0],
         [   2,    2],
         [   4,    4],
         ...,
         [ 186,  187],
         [ 188,  189],
         [ 190,  191]],

        [[ 192,  192],
         [ 194,  194],
         [ 196,  196],
         ...,
         [ 378,  379],
         [ 380,  381],
         [ 382,  383]],

        [[ 384,  384],
         [ 386,  386],
         [ 388,  388],
         ...,
         [ 570,  571],
         [ 572,  573],
         [ 574,  575]],

        ...,

        [[2496, 2497],
         [2498, 2499],
         [2500, 2501],
         ...,
         [2682, 2682],
         [2684, 2684],
         [2686, 2686]],

        [[2688, 2689],
         [2690, 2691],
         [2692, 2693],
         ...,
         [2874, 2874],
         [2876, 2876],
         [2878, 2878]],

        [[2880, 2881],
         [2882, 2883],
         [2884, 2885],
         ...,
         [3066, 3066],
         [3068, 3068],
         [3070, 3070]]])

In [114]:
# Unmodified reference tensor, for comparison with the `e.substitude` output above.
torch.arange(16*96*2).reshape(16, 96, 2)


tensor([[[   0,    1],
         [   2,    3],
         [   4,    5],
         ...,
         [ 186,  187],
         [ 188,  189],
         [ 190,  191]],

        [[ 192,  193],
         [ 194,  195],
         [ 196,  197],
         ...,
         [ 378,  379],
         [ 380,  381],
         [ 382,  383]],

        [[ 384,  385],
         [ 386,  387],
         [ 388,  389],
         ...,
         [ 570,  571],
         [ 572,  573],
         [ 574,  575]],

        ...,

        [[2496, 2497],
         [2498, 2499],
         [2500, 2501],
         ...,
         [2682, 2683],
         [2684, 2685],
         [2686, 2687]],

        [[2688, 2689],
         [2690, 2691],
         [2692, 2693],
         ...,
         [2874, 2875],
         [2876, 2877],
         [2878, 2879]],

        [[2880, 2881],
         [2882, 2883],
         [2884, 2885],
         ...,
         [3066, 3067],
         [3068, 3069],
         [3070, 3071]]])

In [17]:
import torch
x = torch.arange(12).reshape(2, 2, 3 )
idx = torch.tensor([[[0,1],[1,0]],[[1,2],[0,1]]])
x,x[:,:,idx]

(tensor([[[ 0,  1,  2],
          [ 3,  4,  5]],
 
         [[ 6,  7,  8],
          [ 9, 10, 11]]]),
 tensor([[[[[ 0,  1],
            [ 1,  0]],
 
           [[ 1,  2],
            [ 0,  1]]],
 
 
          [[[ 3,  4],
            [ 4,  3]],
 
           [[ 4,  5],
            [ 3,  4]]]],
 
 
 
         [[[[ 6,  7],
            [ 7,  6]],
 
           [[ 7,  8],
            [ 6,  7]]],
 
 
          [[[ 9, 10],
            [10,  9]],
 
           [[10, 11],
            [ 9, 10]]]]]))

In [26]:
# NOTE(review): `jitter` is not defined in any visible cell of this notebook, so this
# cell relies on hidden kernel state; define or import it before running.
# Judging only by the output below, it appears to add small noise and return a
# float64 tensor — unverified.
x = torch.ones((3, 2, 4))
jitter(x)

tensor([[[1.0988, 0.9924, 0.9950, 0.9194],
         [1.0372, 1.1655, 0.9409, 0.9879]],

        [[1.0402, 1.0965, 0.9569, 1.0245],
         [1.0759, 0.8846, 1.1137, 1.0198]],

        [[0.9506, 1.2317, 1.0664, 1.0768],
         [0.9462, 1.1291, 0.9597, 0.9760]]], dtype=torch.float64)

In [125]:
import numpy as np
x = np.array([
    [4, 1, 3],
    [2, 9, 0],
])
# Sorting -x ascending puts each column's largest value first, so row 0 of the argsort
# holds the row index of each column's maximum; [:1] keeps it 2-D with shape (1, ncols).
np.argsort(-x, axis=0)[:1]

array([[0, 1, 0]])