In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from thop import profile
from thop import clever_format
import time
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
batch_size = 16
seq_len_t = 128
channel_d = 64
channel_in = channel_d
channel_out = channel_d
f_dt = torch.rand(batch_size, channel_d, seq_len_t) # BHW
weight = torch.rand(channel_out, channel_in, 3)
out =  F.conv1d(f_dt, weight, bias=None, stride=1, padding=1, dilation=1, groups=1)
print(out.shape)

torch.Size([16, 64, 128])


In [79]:
inputs = torch.randn(1, 3, 3)
filters = torch.randn(3, 3, 2)
out = F.conv1d(inputs, filters)
print(out.shape)

torch.Size([1, 3, 2])


In [54]:
filters

tensor([[[ 1.4881, -0.2732],
         [-1.0998,  0.6145],
         [-0.3854, -0.4885]],

        [[-1.0726, -0.7862],
         [ 0.2072, -0.1803],
         [ 2.6887,  1.0425]],

        [[ 1.8110,  0.2561],
         [-1.2803,  0.7140],
         [ 0.6212, -0.1925]]])

In [55]:
sum(sum(filters[0]*inputs[0,:,:2]))

tensor(0.9133)

In [56]:
sum(sum(filters[0]*inputs[0,:,1:]))

tensor(0.6544)

In [57]:
out

tensor([[[ 0.9133,  0.6544],
         [-5.5761,  0.7350],
         [-1.8866,  1.3048]]])

In [58]:
h = channel_d
channel_in = channel_d
channel_out = channel_d # number of kernels
weight = torch.rand(channel_out, np.int(channel_in/h), 3)
out =  F.conv1d(f_dt, weight,padding =1, groups=h)
print(out.shape)

torch.Size([16, 64, 128])


In [124]:
# depthwise conv: the input channel, filter_channel, groups, and output channels must follows
# filter_channels = input channels/groups
# output channles =  k * input channels/filter_channels=groups
# where k behaves like multi-head
inputs = torch.randn(1, 6, 3)
filters = torch.randn(6, 2, 2)
out = F.conv1d(inputs, filters, groups=3)
print(out.shape)

torch.Size([1, 6, 2])


In [125]:
filters

tensor([[[ 0.1123, -0.0244],
         [ 0.6030, -1.5401]],

        [[-1.8002,  1.1385],
         [ 0.6183,  1.6528]],

        [[-2.3613,  0.4606],
         [ 1.0304,  0.8362]],

        [[ 0.5690, -0.5754],
         [-0.3750, -0.7164]],

        [[-0.7991,  1.6169],
         [ 1.0393,  0.2280]],

        [[ 0.1741,  2.0219],
         [ 0.4011, -0.2367]]])

In [126]:
sum(sum(filters[0]*inputs[0,0:2,:2]))

tensor(-1.5472)

In [131]:
sum(sum(filters[0]*inputs[0,0:2,1:3]))

tensor(3.0578)

In [132]:
sum(sum(filters[1]*inputs[0,0:2,:2]))

tensor(1.0647)

In [134]:
sum(sum(filters[2]*inputs[0,2:4,:2]))

tensor(-2.9396)

In [129]:
out

tensor([[[-1.5472,  3.0578],
         [ 1.0647, -3.7603],
         [-2.9396,  0.3557],
         [ 0.7460,  1.1228],
         [-1.5871, -1.9176],
         [ 0.2089, -2.0115]]])

In [136]:
inputs = torch.randn(1, 6, 3)
filters = torch.randn(3, 2, 2)
out = F.conv1d(inputs, filters, groups=3)
print(out.shape)

torch.Size([1, 3, 2])


In [137]:
filters

tensor([[[-0.2131, -0.3739],
         [ 0.1408, -1.5768]],

        [[-0.8509, -0.6289],
         [ 0.3753, -0.2520]],

        [[-0.2391,  0.0474],
         [-1.0111, -2.1080]]])

In [143]:
inputs = torch.randn(1, 6, 3)
inputs = inputs.view(-1, 3, 3)
filters = torch.randn(3, 1, 3)
out = F.conv1d(inputs, filters, padding=1, groups=3)
out = out.view(1, 6, 3)
print(out.shape)

torch.Size([1, 6, 3])


In [149]:
sum(sum(filters[0]*inputs[0,0,:3]))

tensor(-0.1740)

In [153]:
sum(sum(filters[1]*inputs[0,1,:3]))

tensor(-0.1060)

In [155]:
sum(sum(filters[2]*inputs[0,2,:3]))

tensor(0.1466)

In [159]:
sum(sum(filters[0]*inputs[1,0,:3]))

tensor(-1.6573)

In [161]:
sum(sum(filters[1]*inputs[1,1,:3]))

tensor(0.0046)

In [162]:
inputs =  inputs.view(1, 6, 3)

In [163]:
sum(sum(filters[0]*inputs[0,0,:3]))

tensor(-0.1740)

In [167]:
sum(sum(filters[0]*inputs[0,3,:3]))

tensor(-1.6573)

In [145]:
out

tensor([[[ 0.3233, -0.1740,  0.5302],
         [-0.5915, -0.1060,  1.5928],
         [-0.2367,  0.1466, -0.0522],
         [ 1.4719, -1.6573,  1.6423],
         [-0.2396,  0.0046,  0.6375],
         [-0.7814,  1.1654, -0.8154]]])

## Test LightweightConv

In [27]:
batch_size = 16
seq_len_t = 128
channel_d = 64
f_td = torch.rand(batch_size, seq_len_t, channel_d) # BWH
f_dt = torch.rand(batch_size, channel_d, seq_len_t) # BHW
print(f_td.shape, f_dt.shape)

torch.Size([16, 128, 64]) torch.Size([16, 64, 128])


In [None]:
class LightweightConv(nn.Module):
    '''Lightweight convolution from fairseq.
    Args:
        input_size: # of channels of the input and output
        kernel_size: convolution channels
        padding: padding
        num_heads: number of heads used. The weight is of shape (num_heads, 1, kernel_size)
        weight_softmax: normalize the weight with softmax before the convolution
        dropout: dropout probability
    Forward:
        Input: BxCxT, i.e. (batch_size, input_size, timesteps)
        Output: BxCxT, i.e. (batch_size, input_size, timesteps)
    Attributes:
        weight: learnable weights of shape `(num_heads, 1, kernel_size)`
        bias:   learnable bias of shape `(input_size)`
    '''
    def __init__(self, input_size, kernel_size=1, padding=0, n_heads=1,
                 weight_softmax=True, bias=False, dropout=0.0):
        super().__init__()
        self.input_size = input_size
        self.kernel_size = kernel_size
        self.n_heads = n_heads
        self.padding = padding
        self.weight_softmax = weight_softmax
        self.weight = nn.Parameter(torch.Tensor(n_heads, 1, kernel_size))
        self.bias = nn.Parameter(torch.Tensor(input_size)) if bias else None
        self.dropout = dropout
        self.reset_parameters()

    def forward(self, input):
        B, C, T = input.size()
        H = self.n_heads
        weight = F.softmax(self.weight, dim=-1) if self.weight_softmax else self.weight
        weight = F.dropout(weight, self.weight_dropout, training=self.training)
        
        # Merge every C/H entries into the batch dimension (C = self.input_size)
        # B x C x T -> (B * C/H) x H x T
        # One can also expand the weight to C x 1 x K by a factor of C/H
        # and do not reshape the input instead, which is slow though
        input = input.view(-1, H, T)
        output = F.conv1d(input, weight, padding=self.padding, groups=H)
        output = output.view(B, C, T)
        if self.bias is not None:
            output = output + self.bias.view(1, -1, 1)
        return output

In [122]:
class LightweightConv(nn.Module):
    def __init__(self, input_size, groups, kernel_size=1, padding=0, n_heads=1,
                 weight_softmax=True, bias=False, dropout=0.0):
        super().__init__()
        self.input_size = input_size
        self.kernel_size = kernel_size
        self.n_heads = n_heads
        self.padding = padding
        self.groups = groups
        self.weight_softmax = weight_softmax
        self.weight = nn.Parameter(torch.Tensor(n_heads*input_size, np.int(input_size/groups), kernel_size))
        self.bias = nn.Parameter(torch.Tensor(input_size)) if bias else None
        self.dropout = dropout
        self.weight_dropout = 0.1
        self.training = True

    def forward(self, input):
        B, C, T = input.size()
        H = self.n_heads
        # weight: n_heads, 1, kernel_size 
        weight = F.softmax(self.weight, dim=-1) if self.weight_softmax else self.weight
        print(weight.shape)
        weight = F.dropout(weight, self.weight_dropout, training=self.training)
        # Merge every C/H entries into the batch dimension (C = self.input_size)
        # B x C x T -> (B * C/H) x H x T
        # One can also expand the weight to C x 1 x K by a factor of C/H
        # and do not reshape the input instead, which is slow though
        #input = input.contiguous().view(-1, H, T)
        # input_tensor, kernel, stride=1, padding=1
        output = F.conv1d(input, weight, padding=self.padding, groups=self.groups)
        #output = output.view(B, C, T)
        if self.bias is not None:
            output = output + self.bias.view(1, -1, 1)
        return output

In [123]:
model = LightweightConv(input_size=f_dt.shape[1],groups=f_dt.shape[1], kernel_size=3, padding=1, n_heads=3)
out = model(f_dt)
print(out.shape)

torch.Size([192, 1, 3])
torch.Size([16, 192, 128])


In [30]:
model.to(device)
model.eval()
input_tensor = f_td.to(device)

# Warm up the GPU by performing a few inference runs
for _ in range(5):
    model(input_tensor)

# Measure the inference time
start_time = time.time()
model(input_tensor)
end_time = time.time()
inference_time = end_time - start_time
print(f"Inference Time: {inference_time} seconds")

torch.Size([64, 1, 1])
torch.Size([32, 64, 64])
torch.Size([64, 1, 1])
torch.Size([32, 64, 64])
torch.Size([64, 1, 1])
torch.Size([32, 64, 64])
torch.Size([64, 1, 1])
torch.Size([32, 64, 64])
torch.Size([64, 1, 1])
torch.Size([32, 64, 64])
torch.Size([64, 1, 1])
torch.Size([32, 64, 64])
Inference Time: 0.0006244182586669922 seconds


In [31]:
macs, params = profile(model, inputs=(input_tensor, ))
macs, params = clever_format([macs, params], "%.3f")
print(macs, params)

torch.Size([64, 1, 1])
torch.Size([32, 64, 64])
0.000B 0.000B


## Test SymmetricConv

In [204]:
batch_size = 16
seq_len_t = 128
channel_d = 64
f_dt = torch.rand(batch_size, channel_d, seq_len_t) # BHW
print(f_dt.shape)

torch.Size([16, 64, 128])


In [92]:
def flip_half(output):
    B, C, T = output.size()
    half1 = output[:, :C//2, :]
    half2 = output[:, C//2:, :]
    half2_flipped = torch.flip(half2, dims=[1])
    output = torch.cat((half1, half2_flipped), dim=1)
    return output


class SymmetricLightweightConv(nn.Module):
    def __init__(self, d_size, groups=2, kernel_size=3, padding=1, n_heads=1,
                 weight_softmax=True, bias=False, dropout=0.0):
        super().__init__()
        self.input_size = d_size
        self.kernel_size = kernel_size
        self.n_heads = n_heads
        self.padding = padding
        self.groups = groups
        self.weight_softmax = weight_softmax
        self.weight = nn.Parameter(torch.Tensor(np.int(n_heads*self.input_size/2), 1, kernel_size))
        self.bias = nn.Parameter(torch.Tensor(self.input_size)) if bias else None
        self.dropout = dropout
        self.weight_dropout = 0.1
        self.training = True

    def forward(self, input):
        B, C, T = input.size()
        input = flip_half(input)
        reshaped_tensor = input.view(B * 2, C // 2, T)
        #print(reshaped_tensor.shape)
        H = self.n_heads
        # weight: n_heads, 1, kernel_size 
        #normalized_weights = (self.weight - torch.mean(self.weight)) / torch.std(self.weight)
        weight = F.softmax(self.weight, dim=-1) if self.weight_softmax else self.weight
        weight = F.dropout(weight, self.weight_dropout, training=self.training)
        #print(weight.shape)
        output = F.conv1d(reshaped_tensor, weight, padding=self.padding, groups=np.int(self.groups/2))
        #print(output.shape)
        output = output.view(B, C, T)
        output = flip_half(output)
        if self.bias is not None:
            output = output + self.bias.view(1, -1, 1)
        return output

In [95]:
model = SymmetricLightweightConv(d_size=f_dt.shape[1],groups=f_dt.shape[1], kernel_size=3, padding=1, n_heads=1, weight_softmax=False)
out = model(f_dt)
print(out.shape)

torch.Size([16, 224, 224])


## Test DynamicConv

In [221]:
batch_size = 16
seq_len_t = 128
channel_d = 64
f_td = torch.rand(batch_size, seq_len_t, channel_d) # BWH
f_dt = torch.rand(batch_size, channel_d, seq_len_t) # BHW
print(f_td.shape, f_dt.shape)

torch.Size([16, 128, 64]) torch.Size([16, 64, 128])


In [222]:
class DynamicConv(nn.Module):
    def __init__(self, input_size, kernel_size=1, padding=0, n_heads=1,
                 weight_softmax=True, bias=False, dropout=0.0):
        super().__init__()
        self.input_size = input_size
        self.kernel_size = kernel_size
        self.n_heads = n_heads
        self.padding = padding
        self.weight_softmax = weight_softmax
        self.weight = nn.Parameter(torch.Tensor(n_heads, 1, kernel_size))
        self.weight_linear = nn.Linear(input_size, n_heads * kernel_size, bias)
        self.bias = nn.Parameter(torch.Tensor(input_size)) if bias else None
        self.dropout = dropout
        self.weight_dropout = 0.1
        self.training = True

    def forward(self, input):
        '''Takes input (B x C x T) to output (B x C x T)'''
        
        # Prepare weight (take softmax)
        B, C, T = input.size()
        print(B,C,T)
        H, K = self.n_heads, self.kernel_size
        # weight: n_heads, 1, kernel_size 
        weight = self.weight_linear(input.permute(0,2,1))
        print(weight.shape)
        weight = F.softmax(weight, dim=-1) 
        #weight = F.softmax(self.weight, dim=-1) if self.weight_softmax else self.weight
        weight = F.dropout(weight, self.weight_dropout, training=self.training)
        weight = weight.permute(0, 2, 3, 1)
        weight = weight.contiguous().view(-1, H, T)
        input = input.contiguous().view(-1, H, T)
        # input_tensor, kernel, stride=1, padding=1
        output = F.conv1d(input, weight, padding=self.padding, groups=H)
        output = output.view(B, C, T)
        return output

In [223]:
model = DynamicConv(input_size=f_dt.shape[1], kernel_size=1, padding=0, n_heads=f_dt.shape[1])
out = model(f_dt)
print(out.shape)

16 64 128
torch.Size([16, 128, 64])


RuntimeError: number of dims don't match in permute

## Test Spectral Pooling

In [30]:
import torch
import torch.nn as nn
from torch.autograd import Function
import math
from torch.nn.modules.utils import _pair

def _spectral_crop(input, oheight, owidth):
    cutoff_freq_h = math.ceil(oheight / 2)
    cutoff_freq_w = math.ceil(owidth / 2)

    if oheight % 2 == 1:
        if owidth % 2 == 1:
            top_left = input[:, :, :cutoff_freq_h, :cutoff_freq_w]
            top_right = input[:, :, :cutoff_freq_h, -(cutoff_freq_w-1):]
            bottom_left = input[:, :, -(cutoff_freq_h-1):, :cutoff_freq_w]
            bottom_right = input[:, :, -(cutoff_freq_h-1):, -(cutoff_freq_w-1):]
        else:
            top_left = input[:, :, :cutoff_freq_h, :cutoff_freq_w]
            top_right = input[:, :, :cutoff_freq_h, -cutoff_freq_w:]
            bottom_left = input[:, :, -(cutoff_freq_h-1):, :cutoff_freq_w]
            bottom_right = input[:, :, -(cutoff_freq_h-1):, -cutoff_freq_w:]
    else:
        if owidth % 2 == 1:
            top_left = input[:, :, :cutoff_freq_h, :cutoff_freq_w]
            top_right = input[:, :, :cutoff_freq_h, -(cutoff_freq_w-1):]
            bottom_left = input[:, :, -cutoff_freq_h:, :cutoff_freq_w]
            bottom_right = input[:, :, -cutoff_freq_h:, -(cutoff_freq_w-1):]
        else:
            top_left = input[:, :, :cutoff_freq_h, :cutoff_freq_w]
            top_right = input[:, :, :cutoff_freq_h, -cutoff_freq_w:]
            bottom_left = input[:, :, -cutoff_freq_h:, :cutoff_freq_w]
            bottom_right = input[:, :, -cutoff_freq_h:, -cutoff_freq_w:]

    top_combined = torch.cat((top_left, top_right), dim=-1)
    bottom_combined = torch.cat((bottom_left, bottom_right), dim=-1)
    all_together = torch.cat((top_combined, bottom_combined), dim=-2)

    return all_together

def _spectral_pad(input, output, oheight, owidth):
    cutoff_freq_h = math.ceil(oheight / 2)
    cutoff_freq_w = math.ceil(owidth / 2)
    pad = torch.zeros_like(input)

    if oheight % 2 == 1:
        if owidth % 2 == 1:
            pad[:, :, :cutoff_freq_h, :cutoff_freq_w] = output[:, :, :cutoff_freq_h, :cutoff_freq_w]
            pad[:, :, :cutoff_freq_h, -(cutoff_freq_w-1):] = output[:, :, :cutoff_freq_h, -(cutoff_freq_w-1):]
            pad[:, :, -(cutoff_freq_h-1):, :cutoff_freq_w] = output[:, :, -(cutoff_freq_h-1):, :cutoff_freq_w]
            pad[:, :, -(cutoff_freq_h-1):, -(cutoff_freq_w-1):] = output[:, :, -(cutoff_freq_h-1):, -(cutoff_freq_w-1):]
        else:
            pad[:, :, :cutoff_freq_h, :cutoff_freq_w] = output[:, :, :cutoff_freq_h, :cutoff_freq_w]
            pad[:, :, :cutoff_freq_h, -cutoff_freq_w:] = output[:, :, :cutoff_freq_h, -cutoff_freq_w:]
            pad[:, :, -(cutoff_freq_h-1):, :cutoff_freq_w] = output[:, :, -(cutoff_freq_h-1):, :cutoff_freq_w]
            pad[:, :, -(cutoff_freq_h-1):, -cutoff_freq_w:] = output[:, :, -(cutoff_freq_h-1):, -cutoff_freq_w:]
    else:
        if owidth % 2 == 1:
            pad[:, :, :cutoff_freq_h, :cutoff_freq_w] = output[:, :, :cutoff_freq_h, :cutoff_freq_w]
            pad[:, :, :cutoff_freq_h, -(cutoff_freq_w-1):] = output[:, :, :cutoff_freq_h, -(cutoff_freq_w-1):]
            pad[:, :, -cutoff_freq_h:, :cutoff_freq_w] = output[:, :, -cutoff_freq_h:, :cutoff_freq_w]
            pad[:, :, -cutoff_freq_h:, -(cutoff_freq_w-1):] = output[:, :, -cutoff_freq_h:, -(cutoff_freq_w-1):]
        else:
            pad[:, :, :cutoff_freq_h, :cutoff_freq_w] = output[:, :, :cutoff_freq_h, :cutoff_freq_w]
            pad[:, :, :cutoff_freq_h, -cutoff_freq_w:] = output[:, :, :cutoff_freq_h, -cutoff_freq_w:]
            pad[:, :, -cutoff_freq_h:, :cutoff_freq_w] = output[:, :, -cutoff_freq_h:, :cutoff_freq_w]
            pad[:, :, -cutoff_freq_h:, -cutoff_freq_w:] = output[:, :, -cutoff_freq_h:, -cutoff_freq_w:]	

    return pad

def DiscreteHartleyTransform(input):
    fft = torch.rfft(input, 2, normalized=True, onesided=False)
    # for new version of pytorch
    #fft = torch.fft.fft2(input, dim=(-2, -1), norm='ortho')
    #fft = torch.stack((fft.real, fft.imag), -1)
    dht = fft[:, :, :, :, -2] - fft[:, :, :, :, -1]
    return dht

class SpectralPoolingFunction(Function):
    @staticmethod
    def forward(ctx, input, oheight, owidth):
        ctx.oh = oheight
        ctx.ow = owidth
        ctx.save_for_backward(input)

        # Hartley transform by RFFT
        dht = DiscreteHartleyTransform(input)

        # frequency cropping
        all_together = _spectral_crop(dht, oheight, owidth)
        # inverse Hartley transform
        dht = DiscreteHartleyTransform(all_together)
        return dht

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_variables

        # Hartley transform by RFFT
        dht = DiscreteHartleyTransform(grad_output)
        # frequency padding
        grad_input = _spectral_pad(input, dht, ctx.oh, ctx.ow)
        # inverse Hartley transform
        grad_input = DiscreteHartleyTransform(grad_input)
        return grad_input, None, None

class SpectralPool2d(nn.Module):
    def __init__(self, t_size):
        super(SpectralPool2d, self).__init__()
        self.t_size = t_size
    def forward(self, input):
        H, W = input.size(-2), input.size(-1)
        #h, w = math.ceil(H*self.scale_factor[0]), math.ceil(W*self.scale_factor[1])
        return SpectralPoolingFunction.apply(input, H, self.t_size)



class SpectralPooling_layer(nn.Module):
    def __init__(self, t_size):
        super(SpectralPooling_layer, self).__init__()
        self.t_size = t_size
        self.SpecPool2d = SpectralPool2d(t_size=t_size)

    def forward(self, x):
        # input: batch, in_channel, length 
        x = x.unsqueeze(1)   # input: batch, 1, in_channel, length 
        out = self.SpecPool2d(x)
        return out.squeeze()

In [31]:
model = SpectralPooling_layer(t_size=128)


batch_size = 16
seq_len_t = 153
channel_d = 150
f_dt = torch.rand(batch_size, channel_d, seq_len_t) # BCT
print(f_dt.shape)
out =  model(f_dt)
print(out.shape)

torch.Size([16, 150, 153])
torch.Size([16, 150, 128])


## Test Conv1D Encoder

In [64]:
class Conv1DEncoder(nn.Module):
    def __init__(self, in_channel):
        super(Conv1DEncoder, self).__init__()
        self.conv1 = nn.Sequential(
            # input: batch, in_channel, length 
            # conv1d: in_channel, out_channel, kernel, stride, padding
            # size: (in_size-kernel+2*padding)/stride + 1
            # l_out: batch, out_channel, length

            nn.Conv1d(in_channel, 32, 5, 1, 2),  # out: batch * 32 * 112
            nn.BatchNorm1d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(2),  # batch * 32 * 56

        )
        self.conv2 = nn.Sequential(nn.Conv1d(32, 64, 3, 1, 1),  # out: batch * 64 * 56
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True),
        )
        self.conv3 = nn.Sequential(nn.Conv1d(64, 128, 3, 1, 1),  # out: batch * 64 * 56
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
        )
        

    def forward(self, x):
        # input x:  batch * C * T
        y = self.conv1(x)  # batch * 64 * 112
        #print(y.shape)
        y = self.conv2(y)
        #y = self.conv3(y)
        
        return y

In [66]:
batch_size = 16
seq_len_t = 153
channel_d = 150
f_dt = torch.rand(batch_size, channel_d, seq_len_t) # BHW
print(f_dt.shape)
model = Conv1DEncoder(in_channel = channel_d)
out =  model(f_dt)
print(out.shape)

torch.Size([16, 150, 153])
torch.Size([16, 64, 76])


## Test Attenntionl LSTM

In [58]:
class Attention1D(nn.Module):
    def __init__(self, in_channel):
        super(Attention1D, self).__init__()
        self.tanh = nn.Tanh()
        self.weight = nn.Linear(in_channel,1)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, H):
        M = self.tanh(H)  # (batch, seq_len, rnn_size)
        alpha = self.weight(M).squeeze(2)  # (batch, seq_len)
        alpha = self.softmax(alpha)  # (batch, seq_len)

        r = H * alpha.unsqueeze(2) # (batch, seq_len, rnn_size)
        r = r.sum(dim=1)  # (batch, rnn_size)

        return r, alpha 

class Attentional_LSTM_Pool(nn.Module):
    def __init__(self, input_size, num_classes):
        super(Attentional_LSTM_Pool,self).__init__()
        
        self.attention = Attention1D(in_channel=32)
        self.lstm = nn.LSTM(input_size=64, hidden_size=32, num_layers=3, batch_first=True, bidirectional=False)
        self.fc = nn.Linear(96, num_classes)

    def forward(self,x):
        
        # x:  B * T * C
        x1, (ht,ct) = self.lstm(x) # x1: B, T, bi*hidden_size
        x1, _ = self.attention(x1) # out: batch, bi*hidden_size: 64
        x2 = torch.max(x, 1, keepdim=False)[0] #  B * C
        x_all = torch.cat((x1,x2),dim=1)
        out = self.fc(x_all)
        
        return out

In [63]:
batch_size = 16
seq_len_t = 76
channel_d = 64
f_dt = torch.rand(batch_size, channel_d, seq_len_t) # BCT
f_td = f_dt.transpose(2,1)
print(f_td.shape)
model = Attentional_LSTM_Pool(input_size = channel_d, num_classes = 5)
out =  model(f_td)
print(out.shape)

torch.Size([16, 76, 64])
torch.Size([16, 5])


## Test Rank

In [89]:
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.nn.modules.utils import _pair


def _spectral_crop(input, oheight, owidth):
    cutoff_freq_h = math.ceil(oheight / 2)
    cutoff_freq_w = math.ceil(owidth / 2)

    if oheight % 2 == 1:
        if owidth % 2 == 1:
            top_left = input[:, :, :cutoff_freq_h, :cutoff_freq_w]
            top_right = input[:, :, :cutoff_freq_h, -(cutoff_freq_w-1):]
            bottom_left = input[:, :, -(cutoff_freq_h-1):, :cutoff_freq_w]
            bottom_right = input[:, :, -(cutoff_freq_h-1):, -(cutoff_freq_w-1):]
        else:
            top_left = input[:, :, :cutoff_freq_h, :cutoff_freq_w]
            top_right = input[:, :, :cutoff_freq_h, -cutoff_freq_w:]
            bottom_left = input[:, :, -(cutoff_freq_h-1):, :cutoff_freq_w]
            bottom_right = input[:, :, -(cutoff_freq_h-1):, -cutoff_freq_w:]
    else:
        if owidth % 2 == 1:
            top_left = input[:, :, :cutoff_freq_h, :cutoff_freq_w]
            top_right = input[:, :, :cutoff_freq_h, -(cutoff_freq_w-1):]
            bottom_left = input[:, :, -cutoff_freq_h:, :cutoff_freq_w]
            bottom_right = input[:, :, -cutoff_freq_h:, -(cutoff_freq_w-1):]
        else:
            top_left = input[:, :, :cutoff_freq_h, :cutoff_freq_w]
            top_right = input[:, :, :cutoff_freq_h, -cutoff_freq_w:]
            bottom_left = input[:, :, -cutoff_freq_h:, :cutoff_freq_w]
            bottom_right = input[:, :, -cutoff_freq_h:, -cutoff_freq_w:]

    top_combined = torch.cat((top_left, top_right), dim=-1)
    bottom_combined = torch.cat((bottom_left, bottom_right), dim=-1)
    all_together = torch.cat((top_combined, bottom_combined), dim=-2)

    return all_together

def _spectral_pad(input, output, oheight, owidth):
    cutoff_freq_h = math.ceil(oheight / 2)
    cutoff_freq_w = math.ceil(owidth / 2)
    pad = torch.zeros_like(input)

    if oheight % 2 == 1:
        if owidth % 2 == 1:
            pad[:, :, :cutoff_freq_h, :cutoff_freq_w] = output[:, :, :cutoff_freq_h, :cutoff_freq_w]
            pad[:, :, :cutoff_freq_h, -(cutoff_freq_w-1):] = output[:, :, :cutoff_freq_h, -(cutoff_freq_w-1):]
            pad[:, :, -(cutoff_freq_h-1):, :cutoff_freq_w] = output[:, :, -(cutoff_freq_h-1):, :cutoff_freq_w]
            pad[:, :, -(cutoff_freq_h-1):, -(cutoff_freq_w-1):] = output[:, :, -(cutoff_freq_h-1):, -(cutoff_freq_w-1):]
        else:
            pad[:, :, :cutoff_freq_h, :cutoff_freq_w] = output[:, :, :cutoff_freq_h, :cutoff_freq_w]
            pad[:, :, :cutoff_freq_h, -cutoff_freq_w:] = output[:, :, :cutoff_freq_h, -cutoff_freq_w:]
            pad[:, :, -(cutoff_freq_h-1):, :cutoff_freq_w] = output[:, :, -(cutoff_freq_h-1):, :cutoff_freq_w]
            pad[:, :, -(cutoff_freq_h-1):, -cutoff_freq_w:] = output[:, :, -(cutoff_freq_h-1):, -cutoff_freq_w:]
    else:
        if owidth % 2 == 1:
            pad[:, :, :cutoff_freq_h, :cutoff_freq_w] = output[:, :, :cutoff_freq_h, :cutoff_freq_w]
            pad[:, :, :cutoff_freq_h, -(cutoff_freq_w-1):] = output[:, :, :cutoff_freq_h, -(cutoff_freq_w-1):]
            pad[:, :, -cutoff_freq_h:, :cutoff_freq_w] = output[:, :, -cutoff_freq_h:, :cutoff_freq_w]
            pad[:, :, -cutoff_freq_h:, -(cutoff_freq_w-1):] = output[:, :, -cutoff_freq_h:, -(cutoff_freq_w-1):]
        else:
            pad[:, :, :cutoff_freq_h, :cutoff_freq_w] = output[:, :, :cutoff_freq_h, :cutoff_freq_w]
            pad[:, :, :cutoff_freq_h, -cutoff_freq_w:] = output[:, :, :cutoff_freq_h, -cutoff_freq_w:]
            pad[:, :, -cutoff_freq_h:, :cutoff_freq_w] = output[:, :, -cutoff_freq_h:, :cutoff_freq_w]
            pad[:, :, -cutoff_freq_h:, -cutoff_freq_w:] = output[:, :, -cutoff_freq_h:, -cutoff_freq_w:]	

    return pad

def DiscreteHartleyTransform(input):
    fft = torch.rfft(input, 2, normalized=True, onesided=False)
    # for new version of pytorch
    #fft = torch.fft.fft2(input, dim=(-2, -1), norm='ortho')
    #fft = torch.stack((fft.real, fft.imag), -1)
    dht = fft[:, :, :, :, -2] - fft[:, :, :, :, -1]
    return dht

class SpectralPoolingFunction(Function):
    @staticmethod
    def forward(ctx, input, oheight, owidth):
        ctx.oh = oheight
        ctx.ow = owidth
        ctx.save_for_backward(input)

        # Hartley transform by RFFT
        dht = DiscreteHartleyTransform(input)

        # frequency cropping
        all_together = _spectral_crop(dht, oheight, owidth)
        # inverse Hartley transform
        dht = DiscreteHartleyTransform(all_together)
        return dht

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_variables

        # Hartley transform by RFFT
        dht = DiscreteHartleyTransform(grad_output)
        # frequency padding
        grad_input = _spectral_pad(input, dht, ctx.oh, ctx.ow)
        # inverse Hartley transform
        grad_input = DiscreteHartleyTransform(grad_input)
        return grad_input, None, None

class SpectralPool2d(nn.Module):
    def __init__(self, t_size):
        super(SpectralPool2d, self).__init__()
        self.t_size = t_size
    def forward(self, input):
        H, W = input.size(-2), input.size(-1)
        #h, w = math.ceil(H*self.scale_factor[0]), math.ceil(W*self.scale_factor[1])
        return SpectralPoolingFunction.apply(input, H, self.t_size)



class SpectralPooling_layer(nn.Module):
    def __init__(self, t_size):
        super(SpectralPooling_layer, self).__init__()
        self.t_size = t_size
        self.SpecPool2d = SpectralPool2d(t_size=t_size)

    def forward(self, x):
        # input: batch, in_channel, length 
        x = x.unsqueeze(1)   # input: batch, 1, in_channel, length 
        out = self.SpecPool2d(x)
        return out.squeeze()
    
    
class SymmetricLightweightConv(nn.Module):
    def __init__(self, d_size, groups=2, kernel_size=3, padding=1, n_heads=1,
                 weight_softmax=True, bias=False, dropout=0.0):
        super().__init__()
        self.input_size = d_size
        self.kernel_size = kernel_size
        self.n_heads = n_heads
        self.padding = padding
        self.groups = groups
        self.weight_softmax = weight_softmax
        self.weight = nn.Parameter(torch.Tensor(np.int(n_heads*self.input_size/2), 1, kernel_size))
        self.bias = nn.Parameter(torch.Tensor(self.input_size)) if bias else None
        self.dropout = dropout
        self.weight_dropout = 0.1
        self.training = True

    def forward(self, input):
        B, C, T = input.size()
        half1 = input[:, :C//2, :]
        half2 = input[:, C//2:, :]
        half2_flipped = torch.flip(half2, dims=[2])
        concatenated_tensor = torch.cat((half1, half2_flipped), dim=1)
        reshaped_tensor = concatenated_tensor.view(B * 2, C // 2, T)
        #print(reshaped_tensor.shape)
        H = self.n_heads
        # weight: n_heads, 1, kernel_size 
        weight = F.softmax(self.weight, dim=-1) if self.weight_softmax else self.weight
        weight = F.dropout(weight, self.weight_dropout, training=self.training)
        #print(weight.shape)
        output = F.conv1d(reshaped_tensor, weight, padding=self.padding, groups=np.int(self.groups/2))
        #print(output.shape)
        output = output.view(B, C, T)
        if self.bias is not None:
            output = output + self.bias.view(1, -1, 1)
        return output

    
class Conv1DEncoder(nn.Module):
    def __init__(self, d_size):
        super(Conv1DEncoder, self).__init__()
        self.conv1 = nn.Sequential(
            # input: batch, in_channel, length 
            # conv1d: in_channel, out_channel, kernel, stride, padding
            # size: (in_size-kernel+2*padding)/stride + 1
            # l_out: batch, out_channel, length

            nn.Conv1d(d_size, 32, 5, 1, 2),  # out: batch * 32 * 112
            nn.BatchNorm1d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(2),  # batch * 32 * 56

        )
        self.conv2 = nn.Sequential(nn.Conv1d(32, 64, 3, 1, 1),  # out: batch * 64 * 56
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True),
        )
        self.conv3 = nn.Sequential(nn.Conv1d(64, 128, 3, 1, 1),  # out: batch * 64 * 56
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
        )
        

    def forward(self, x):
        # input x:  batch * C * T
        y = self.conv1(x)  # batch * 64 * 112
        #print(y.shape)
        y = self.conv2(y)
        #y = self.conv3(y)
        
        return y
    
    
class Attention1D(nn.Module):
    def __init__(self, in_channel):
        super(Attention1D, self).__init__()
        self.tanh = nn.Tanh()
        self.weight = nn.Linear(in_channel,1)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, H):
        M = self.tanh(H)  # (batch, seq_len, rnn_size)
        alpha = self.weight(M).squeeze(2)  # (batch, seq_len)
        alpha = self.softmax(alpha)  # (batch, seq_len)

        r = H * alpha.unsqueeze(2) # (batch, seq_len, rnn_size)
        r = r.sum(dim=1)  # (batch, rnn_size)

        return r, alpha 

    
class Attentional_LSTM_Pool(nn.Module):
    def __init__(self, d_size):
        super(Attentional_LSTM_Pool,self).__init__()
        
        self.attention = Attention1D(in_channel=64)
        self.lstm = nn.LSTM(input_size=d_size, hidden_size=32, num_layers=3, batch_first=True, bidirectional=True)

    def forward(self,x):
        
        # x:  B * T * C
        x1, (ht,ct) = self.lstm(x) # x1: B, T, bi*hidden_size
        x1, _ = self.attention(x1) # out: batch, bi*hidden_size: 64
        x2 = torch.max(x, 1, keepdim=False)[0] #  B * C
        out = torch.cat((x1,x2),dim=1)        
        return out, x1


class Conv1DLSTM_All(nn.Module):
    def __init__(self, d_size, t_size, num_classes=9):
        super(Conv1DLSTM_All, self).__init__()
        self.d_size = d_size
        self.t_size = t_size
        self.conv0 = SymmetricLightweightConv(d_size=self.d_size,groups=self.d_size, kernel_size=3, padding=1, n_heads=1)
        self.pool = SpectralPooling_layer(self.t_size)
        self.conv1 = Conv1DEncoder(d_size = self.d_size)
        self.attention_lstm_pool = Attentional_LSTM_Pool(d_size = 64)
        self.fc = nn.Linear(64+64, num_classes)


    def forward(self, x):
        # input x: batch * 1 * 224 * 224
        y = x.squeeze(1)  # batch * C * T
        y = self.conv0(y)
        y = self.pool(y) # batch * C * T
        y = self.conv1(y)  # batch * C * T
        y = y.transpose(2, 1)  # batch * T(64) * C(128)
        out, f_st = self.attention_lstm_pool(y)  # out: batch * 112 * 64
        out = self.fc(out)
        print(f_st.shape, y.shape)
        
        return out, f_st, y
    

def smoothSeq(seq):
    cumulative_sum = torch.cumsum(seq, dim=1)
    accumulated_time = torch.arange(1, seq.size(1) + 1, dtype=seq.dtype, device=seq.device)
    smoothed_seq = cumulative_sum / accumulated_time.view(1, seq.size(1), 1)
    return smoothed_seq


def softplus(x):
    return torch.log(1 + torch.exp(x))


def rank_loss(f_st, f, beta):
    loss = 0
    _, length, feature_size = f.shape
    f_smooth = smoothSeq(f)
    for i in range(length-1):
        theta = torch.sum(f_st.squeeze() * f_smooth[:, i+1, :].squeeze(), dim=1) - torch.sum(f_st.squeeze() * f_smooth[:, i, :].squeeze(), dim=1) + beta
        time_loss = softplus(theta) 
        #print(time_loss)
        #print(loss)
        loss += time_loss
    loss /=  length-1
    return torch.mean(loss)

In [90]:
import torch
import torch.nn as nn


class Attention1D(nn.Module):
    def __init__(self, in_channel:int):
        super(Attention1D, self).__init__()
        self.tanh = nn.Tanh()
        self.weight = nn.Linear(in_channel,1)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, H):
        M = self.tanh(H)  # (batch, seq_len, rnn_size)  seq_len可以理解为时间维 rnn_size为lstm输入
        alpha = self.weight(M).squeeze(2)  # (batch, seq_len)
        alpha = self.softmax(alpha)  # (batch, seq_len)

        r = H * alpha.unsqueeze(2) # (batch, seq_len, rnn_size)
        r = r.sum(dim=1)  # (batch, rnn_size)

        return r, alpha

    
class Conv1DLSTM(nn.Module):
    def __init__(self, num_classes=9, lstm_type='plain'):
        super(Conv1DLSTM, self).__init__()
        self.lstm_type = lstm_type
        self.conv1 = nn.Sequential(
            # input: batch, in_channel, length 
            # conv1d: in_channel, out_channel, kernel, stride, padding
            # size: (in_size-kernel+2*padding)/stride + 1
            # l_out: batch, out_channel, length
            nn.Conv1d(224, 32, 4, 2, 1),  # out: batch * 32 * 112
            nn.BatchNorm1d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(2, 2),  # batch * 32 * 56
            nn.Conv1d(32, 64, 3, 1, 1),  # out: batch * 64 * 56
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(3, 1, 1),  # batch * 64 * 56
            nn.ReLU(inplace=True),
        )
        self.lstm = nn.LSTM(input_size=64, hidden_size=32, num_layers=3, batch_first=True, bidirectional=True)
        self.attention = Attention1D(in_channel=64)
        self.fc = nn.Linear(64, num_classes)
        self.fc1 = nn.Linear(384, num_classes)
        

    def forward(self, x):
        # input x: batch * 1 * 224 * 224
        y = x.squeeze(1)  # batch * 224 * 224
        y = self.conv1(y)  # batch * 64 * 112
        y = y.transpose(2, 1)  # batch * 112 * 64
        #lstm in: batch, length, feature_in
        #lstm out: batch, length, feature_out * 2(bi)
   
        if self.lstm_type == 'plain':
            out, hidden = self.lstm(y)  # out: batch * 112 * 64
            f_st = out[:, -1, :] # out: batch * 64
            out = self.fc(f_st)
            
        elif self.lstm_type == 'attention':
            out, hidden = self.lstm(y)  # out: batch * 112 * 64
            _, alpha = self.attention(out) # out: batch * 64
            out = out * alpha.unsqueeze(2)
            #print(out.shape)
            f_st = out.sum(dim=1)  # (batch, rnn_size)
            out = self.fc(f_st)
        
        return out, f_st, y
    

In [91]:
batch_size = 16
seq_len_t = 224
channel_d = 224
model = Conv1DLSTM_All(d_size=channel_d, t_size=128, num_classes=9)
f_dt = torch.rand(batch_size, channel_d, seq_len_t) # BHW
out, f_st, y  =  model(f_dt)
print(f_st.shape, y.shape)
rank_loss(f_st, y, 0.1)

torch.Size([16, 64]) torch.Size([16, 64, 64])
torch.Size([16, 64]) torch.Size([16, 64, 64])


tensor(nan, grad_fn=<MeanBackward0>)

In [76]:
batch_size = 16
seq_len_t = 224
channel_d = 224
model = Conv1DLSTM(num_classes=9, lstm_type='attention')
f_dt = torch.rand(batch_size, channel_d, seq_len_t) # BHW
out, f_st, y  =  model(f_dt)
print(f_st.shape, y.shape)
rank_loss(f_st, y, 0.1)

torch.Size([16, 64]) torch.Size([16, 56, 64])


tensor(0.7426, grad_fn=<MeanBackward0>)

## Test All

In [60]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from pooling import Pooling_layer
import numpy as np


class SymmetricLightweightConv(nn.Module):
    def __init__(self, d_size, groups=2, kernel_size=3, padding=1, n_heads=1,
                 weight_softmax=True, bias=False, dropout=0.0):
        super().__init__()
        self.input_size = d_size
        self.kernel_size = kernel_size
        self.n_heads = n_heads
        self.padding = padding
        self.groups = groups
        self.weight_softmax = weight_softmax
        self.weight = nn.Parameter(torch.Tensor(np.int(n_heads*self.input_size/2), 1, kernel_size))
        self.bias = nn.Parameter(torch.Tensor(self.input_size)) if bias else None
        self.dropout = dropout
        self.weight_dropout = 0.1
        self.training = True

    def forward(self, input):
        B, C, T = input.size()
        half1 = input[:, :C//2, :]
        half2 = input[:, C//2:, :]
        half2_flipped = torch.flip(half2, dims=[2])
        concatenated_tensor = torch.cat((half1, half2_flipped), dim=1)
        reshaped_tensor = concatenated_tensor.view(B * 2, C // 2, T)
        #print(reshaped_tensor.shape)
        H = self.n_heads
        # weight: n_heads, 1, kernel_size 
        weight = F.softmax(self.weight, dim=-1) if self.weight_softmax else self.weight
        weight = F.dropout(weight, self.weight_dropout, training=self.training)
        #print(weight.shape)
        output = F.conv1d(reshaped_tensor, weight, padding=self.padding, groups=np.int(self.groups/2))
        #print(output.shape)
        output = output.view(B, C, T)
        if self.bias is not None:
            output = output + self.bias.view(1, -1, 1)
        return output

    
class Conv1DEncoder(nn.Module):
    def __init__(self, d_size):
        super(Conv1DEncoder, self).__init__()
        self.conv1 = nn.Sequential(
            # input: batch, in_channel, length 
            # conv1d: in_channel, out_channel, kernel, stride, padding
            # size: (in_size-kernel+2*padding)/stride + 1
            # l_out: batch, out_channel, length

            nn.Conv1d(d_size, 32, 5, 1, 2),  # out: batch * 32 * 112
            nn.BatchNorm1d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(2),  # batch * 32 * 56

        )
        self.conv2 = nn.Sequential(nn.Conv1d(32, 64, 3, 1, 1),  # out: batch * 64 * 56
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True),
        )
        self.conv3 = nn.Sequential(nn.Conv1d(64, 128, 3, 1, 1),  # out: batch * 64 * 56
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
        )
        

    def forward(self, x):
        # input x:  batch * C * T
        y = self.conv1(x)  # batch * 64 * 112
        #print(y.shape)
        y = self.conv2(y)
        y = self.conv3(y)
        
        return y
    
    
class Attention1D(nn.Module):
    def __init__(self, in_channel):
        super(Attention1D, self).__init__()
        self.tanh = nn.Tanh()
        self.weight = nn.Linear(in_channel,1)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, H):
        M = self.tanh(H)  # (batch, seq_len, rnn_size)
        alpha = self.weight(M).squeeze(2)  # (batch, seq_len)
        alpha = self.softmax(alpha)  # (batch, seq_len)

        r = H * alpha.unsqueeze(2) # (batch, seq_len, rnn_size)
        r = r.sum(dim=1)  # (batch, rnn_size)

        return r, alpha 

    
class Attentional_LSTM_Pool(nn.Module):
    def __init__(self, d_size):
        super(Attentional_LSTM_Pool,self).__init__()
        
        self.attention = Attention1D(in_channel=32)
        self.lstm = nn.LSTM(input_size=d_size, hidden_size=32, num_layers=3, batch_first=True, bidirectional=False)

    def forward(self,x):
        
        # x:  B * T * C
        x1, (ht,ct) = self.lstm(x) # x1: B, T, bi*hidden_size
        x1, _ = self.attention(x1) # out: batch, bi*hidden_size: 64
        x2 = torch.max(x, 1, keepdim=False)[0] #  B * C
        out = torch.cat((x1,x2),dim=1)        
        return out


class Conv1DLSTM_All(nn.Module):
    def __init__(self, d_size, t_size, num_classes=9):
        super(Conv1DLSTM_All, self).__init__()
        self.d_size = d_size
        self.t_size = t_size
        self.conv0 = SymmetricLightweightConv(d_size=self.d_size,groups=self.d_size, kernel_size=3, padding=1, n_heads=1)
        self.pool = SpectralPooling_layer(self.t_size)
        self.conv1 = Conv1DEncoder(d_size = self.d_size)
        self.attention_lstm_pool = Attentional_LSTM_Pool(d_size = 128)
        self.fc = nn.Linear(160, num_classes)


    def forward(self, x):
        # input x: batch * 1 * 224 * 224
        y = x.squeeze(1)  # batch * C * T
        y = self.conv0(y)
        y = self.pool(y)
        y = self.conv1(y)  # batch * 64 * 112
        y = y.transpose(2, 1)  # batch * 112 * 64
        out = self.attention_lstm_pool(y)  # out: batch * 112 * 64
        out = self.fc(out)
        
        return out

In [61]:
batch_size = 16
seq_len_t = 224
channel_d = 224
model = Conv1DLSTM_All(d_size=channel_d, t_size=128, num_classes=9)
f_dt = torch.rand(batch_size, channel_d, seq_len_t) # BHW
out =  model(f_dt)
print(out.shape)

torch.Size([16, 9])


In [None]:
import torch
import torch.nn as nn
from torch.autograd import Function
import math
from torch.nn.modules.utils import _pair

def _spectral_crop(input, oheight, owidth):
    cutoff_freq_h = math.ceil(oheight / 2)
    cutoff_freq_w = math.ceil(owidth / 2)

    if oheight % 2 == 1:
        if owidth % 2 == 1:
            top_left = input[:, :, :cutoff_freq_h, :cutoff_freq_w]
            top_right = input[:, :, :cutoff_freq_h, -(cutoff_freq_w-1):]
            bottom_left = input[:, :, -(cutoff_freq_h-1):, :cutoff_freq_w]
            bottom_right = input[:, :, -(cutoff_freq_h-1):, -(cutoff_freq_w-1):]
        else:
            top_left = input[:, :, :cutoff_freq_h, :cutoff_freq_w]
            top_right = input[:, :, :cutoff_freq_h, -cutoff_freq_w:]
            bottom_left = input[:, :, -(cutoff_freq_h-1):, :cutoff_freq_w]
            bottom_right = input[:, :, -(cutoff_freq_h-1):, -cutoff_freq_w:]
    else:
        if owidth % 2 == 1:
            top_left = input[:, :, :cutoff_freq_h, :cutoff_freq_w]
            top_right = input[:, :, :cutoff_freq_h, -(cutoff_freq_w-1):]
            bottom_left = input[:, :, -cutoff_freq_h:, :cutoff_freq_w]
            bottom_right = input[:, :, -cutoff_freq_h:, -(cutoff_freq_w-1):]
        else:
            top_left = input[:, :, :cutoff_freq_h, :cutoff_freq_w]
            top_right = input[:, :, :cutoff_freq_h, -cutoff_freq_w:]
            bottom_left = input[:, :, -cutoff_freq_h:, :cutoff_freq_w]
            bottom_right = input[:, :, -cutoff_freq_h:, -cutoff_freq_w:]

    top_combined = torch.cat((top_left, top_right), dim=-1)
    bottom_combined = torch.cat((bottom_left, bottom_right), dim=-1)
    all_together = torch.cat((top_combined, bottom_combined), dim=-2)

    return all_together

def _spectral_pad(input, output, owidth):
    cutoff_freq_w = math.ceil(owidth / 2)
    pad = torch.zeros_like(input)

    if oheight % 2 == 1:
        if owidth % 2 == 1:
            pad[:, :, :cutoff_freq_h, :cutoff_freq_w] = output[:, :, :cutoff_freq_h, :cutoff_freq_w]
            pad[:, :, :cutoff_freq_h, -(cutoff_freq_w-1):] = output[:, :, :cutoff_freq_h, -(cutoff_freq_w-1):]
            pad[:, :, -(cutoff_freq_h-1):, :cutoff_freq_w] = output[:, :, -(cutoff_freq_h-1):, :cutoff_freq_w]
            pad[:, :, -(cutoff_freq_h-1):, -(cutoff_freq_w-1):] = output[:, :, -(cutoff_freq_h-1):, -(cutoff_freq_w-1):]
        else:
            pad[:, :, :cutoff_freq_h, :cutoff_freq_w] = output[:, :, :cutoff_freq_h, :cutoff_freq_w]
            pad[:, :, :cutoff_freq_h, -cutoff_freq_w:] = output[:, :, :cutoff_freq_h, -cutoff_freq_w:]
            pad[:, :, -(cutoff_freq_h-1):, :cutoff_freq_w] = output[:, :, -(cutoff_freq_h-1):, :cutoff_freq_w]
            pad[:, :, -(cutoff_freq_h-1):, -cutoff_freq_w:] = output[:, :, -(cutoff_freq_h-1):, -cutoff_freq_w:]
    else:
        if owidth % 2 == 1:
            pad[:, :, :cutoff_freq_h, :cutoff_freq_w] = output[:, :, :cutoff_freq_h, :cutoff_freq_w]
            pad[:, :, :cutoff_freq_h, -(cutoff_freq_w-1):] = output[:, :, :cutoff_freq_h, -(cutoff_freq_w-1):]
            pad[:, :, -cutoff_freq_h:, :cutoff_freq_w] = output[:, :, -cutoff_freq_h:, :cutoff_freq_w]
            pad[:, :, -cutoff_freq_h:, -(cutoff_freq_w-1):] = output[:, :, -cutoff_freq_h:, -(cutoff_freq_w-1):]
        else:
            pad[:, :, :cutoff_freq_h, :cutoff_freq_w] = output[:, :, :cutoff_freq_h, :cutoff_freq_w]
            pad[:, :, :cutoff_freq_h, -cutoff_freq_w:] = output[:, :, :cutoff_freq_h, -cutoff_freq_w:]
            pad[:, :, -cutoff_freq_h:, :cutoff_freq_w] = output[:, :, -cutoff_freq_h:, :cutoff_freq_w]
            pad[:, :, -cutoff_freq_h:, -cutoff_freq_w:] = output[:, :, -cutoff_freq_h:, -cutoff_freq_w:]	

    return pad

def DiscreteHartleyTransform(input):
    fft = torch.rfft(input, 2, normalized=True, onesided=False)
    # for new version of pytorch
    #fft = torch.fft.fft2(input, dim=(-2, -1), norm='ortho')
    #fft = torch.stack((fft.real, fft.imag), -1)
    dht = fft[:, :, :, :, -2] - fft[:, :, :, :, -1]
    return dht

class SpectralPoolingFunction(Function):
    @staticmethod
    def forward(ctx, input, owidth):
        ctx.ow = owidth
        ctx.save_for_backward(input)
        # Hartley transform by RFFT
        dht = DiscreteHartleyTransform(input)
        # frequency cropping
        all_together = _spectral_crop(dht, owidth)
        # inverse Hartley transform
        dht = DiscreteHartleyTransform(all_together)
        return dht

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_variables

        # Hartley transform by RFFT
        dht = DiscreteHartleyTransform(grad_output)
        # frequency padding
        grad_input = _spectral_pad(input, dht, ctx.ow)
        # inverse Hartley transform
        grad_input = DiscreteHartleyTransform(grad_input)
        return grad_input, None, None

class SpectralPool1d(nn.Module):
    def __init__(self, t_size):
        super(SpectralPool2d, self).__init__()
        self.t_size = t_size
    def forward(self, input):
        #H, W = input.size(-2), input.size(-1)
        #h, w = math.ceil(H*self.scale_factor[0]), math.ceil(W*self.scale_factor[1])
        return SpectralPoolingFunction.apply(input, self.t_size)



class SpectralPooling_layer(nn.Module):
    def __init__(self, t_size):
        super(SpectralPooling_layer, self).__init__()
        self.t_size = t_size
        self.SpecPool2d = SpectralPool2d(t_size=t_size)

    def forward(self, x):
        # input: batch, in_channel, length 
        x = x.unsqueeze(1)   # input: batch, 1, in_channel, length 
        out = self.SpecPool1d(x)
        return out.squeeze()