In [None]:
import torch


class FreqLinear(torch.nn.Module):
    """Learnable frequency-domain filter followed by a complex channel-mixing linear layer.

    The input is transformed with an orthonormal real FFT along ``dim=1``.
    - Every channel's spectrum is multiplied element-wise by a bank of ``kernal_size`` complex filters of shape ``(fft_len, kernal_size)``.
    - The filtered spectra of all channels are concatenated.
    - A complex ``Linear`` layer mixes them back to ``in_channels``.
    - An inverse real FFT of the original length maps the result back to the time domain.

    Args:
        in_channels: number of channels in the last axis of the input.
        fft_len: number of rfft bins along ``dim=1``; should equal ``seq_len // 2 + 1``
            so the filter bank lines up with the spectrum.
        kernal_size: number of complex filters applied to each channel
            (parameter name, including its spelling, kept for backward compatibility).
    """

    def __init__(self, in_channels, fft_len, kernal_size) -> None:
        super().__init__()
        # Complex filter bank initialised to 1 + 0j in every bin (an identity response).
        self.kernel = torch.nn.Parameter(
            torch.complex(
                torch.ones(fft_len, kernal_size), torch.zeros(fft_len, kernal_size)
            ),
            requires_grad=True,
        )
        # Mixes the kernal_size * in_channels filtered spectra back to in_channels.
        # The weights are complex so the layer acts on spectra directly.
        self.linear = torch.nn.Linear(kernal_size * in_channels, in_channels).to(torch.cfloat)
        self.in_channels = in_channels
        self.fft_len = fft_len

    def forward(self, x):
        """Filter ``x`` in the frequency domain.

        Args:
            x: real tensor whose axis 1 is time and whose last axis holds
               ``in_channels`` channels, e.g. ``(batch, seq_len, in_channels)``.

        Returns:
            Real tensor with the same shape as ``x``.
        """
        seq_len = x.shape[1]
        spec = torch.fft.rfft(x, norm="ortho", dim=1)  # (batch, fft_len, C), complex
        # (batch, fft_len, C, 1) * (fft_len, 1, K) -> (batch, fft_len, C, K).
        # Flattening the last two axes gives the channel-major layout
        # [c0k0 .. c0k(K-1), c1k0, ...], which is what self.linear was built for.
        filtered = spec.unsqueeze(-1) * self.kernel.unsqueeze(1)
        filtered = filtered.flatten(start_dim=-2)
        mixed = self.linear(filtered)
        # Pass n explicitly. Without it, irfft returns 2 * (fft_len - 1) samples,
        # which is one short of the input whenever seq_len is odd.
        return torch.fft.irfft(mixed, n=seq_len, norm="ortho", dim=1)


class TimeConv(torch.nn.Module):
    """Channel-preserving 1-D convolution over time.

    Wraps a bias-free ``Conv1d`` with ``kernel_size=3`` and ``padding=1``.
    With stride 1 this keeps both the channel count and the length unchanged.
    The input layout follows ``Conv1d``: ``(batch, in_channel, length)``.
    """

    def __init__(self, in_channel) -> None:
        super().__init__()
        self.conv = torch.nn.Conv1d(
            in_channels=in_channel,
            out_channels=in_channel,
            kernel_size=3,
            padding=1,
            bias=False,
        )

    def forward(self, x):
        """Apply the convolution; the output has the same shape as ``x``."""
        return self.conv(x)


# Smoke test: one forward/backward pass of FreqLinear on a smooth two-channel signal.
torch.manual_seed(0)  # the complex Linear layer's initial weights are random
x = torch.sin(torch.arange(288) / 100).reshape(1, -1, 2)  # (1, 144, 2)
m = FreqLinear(in_channels=x.shape[-1], fft_len=x.shape[1] // 2 + 1, kernal_size=10)
for name, p in m.named_parameters():
    print(name, p.numel())
y = torch.ones_like(x)
loss_fn = torch.nn.MSELoss()
# MSELoss takes (input, target): the prediction goes first, the target second.
loss = loss_fn(m(x), y)
loss.backward()

# x = x.repeat(2,1,1)
# kernel = torch.rand(x.shape[1] // 2 + 1, 3)
# x = torch.fft.rfft(x, dim=1)
# print(x)
# print(kernel)
# print(x.shape)
# print(kernel.shape)
# x * kernel

In [None]:
import torch

# Check that average pooling commutes with per-series standardisation.
# Pooling the normalised series and then undoing the normalisation should
# give the same result as pooling the raw series.
a = torch.randn(12).reshape(1, -1, 1).float()  # (batch=1, time=12, channel=1)
avg = torch.nn.AvgPool1d(3, 3)


def _pool_over_time(series):
    # AvgPool1d pools the last axis, so move time there and back again.
    return avg(series.transpose(1, 2)).transpose(1, 2)


a_avg = _pool_over_time(a)

mean = a.mean(dim=1, keepdim=True)
stdev = (a.var(dim=1, keepdim=True, unbiased=False) + 1e-5).sqrt()
a_norm = (a - mean) / stdev
a_norm_avg = _pool_over_time(a_norm)
torch.allclose(a_norm_avg * stdev + mean, a_avg)

In [None]:
import torch
# Show how an affine map b = k1 * a + k2 acts on an orthonormal rfft.
# The whole spectrum scales by k1. The constant k2 only shifts the DC bin,
# by k2 * N / sqrt(N) = k2 * sqrt(N) under norm="ortho".
a = torch.arange(16).float()
a_fft_norm = torch.fft.rfft(a, norm='ortho')

k1 = 12
k2 = -2

print(a_fft_norm)
b = k1 * a + k2
print(b)
b_fft = a_fft_norm * k1  # multiplication already returns a new tensor; no clone needed
b_fft[0] += k2 * len(a) ** 0.5
print(b_fft)
# Pass n so the round trip also recovers odd-length signals.
b_rec = torch.fft.irfft(b_fft, n=len(a), norm='ortho')
print(b_rec)
assert torch.allclose(b_fft, torch.fft.rfft(b, norm='ortho'), atol=1e-3), "spectrum of k1 * a + k2 mismatch"
assert torch.allclose(b_rec, b, atol=1e-4), "round trip did not reproduce k1 * a + k2"

In [None]:
import torch
from src.utils.schedule import linear_schedule, cosine_schedule
import matplotlib.pyplot as plt

# linear_schedule returns four values; only the fourth one is plotted here.
# NOTE(review): confirm which schedule quantity that fourth value is before labelling the axes.
_, _, _, a = linear_schedule(1e-4, 1e-2, 288)
ax = plt.gca()
ax.plot(a)

In [None]:
import torch

ks = 100  # moving-average window length

# Alternative test signal (disabled): a noisy sine wave.
# a = torch.sin(torch.linspace(0, torch.pi * 2, 1000)).float()
# a = a + torch.randn_like(a) * 0.1
a = torch.arange(100000).float()


def _print_std_mean(series):
    # torch.std_mean returns (std, mean) in that order.
    spread, centre = torch.std_mean(series)
    print(spread, centre)
    return spread, centre


a_std, a_mean = _print_std_mean(a)
# Stride-1 moving average; avg_pool1d expects (batch, channel, length).
a = torch.nn.functional.avg_pool1d(a.view(1, 1, -1), ks, 1).reshape(-1)
a_std, a_mean = _print_std_mean(a)

In [None]:
import torch
import matplotlib.pyplot as plt

N = 1000
t = torch.linspace(0, 100, N)
# Periodic component plus a slowly growing exponential trend.
a_clean = torch.sin(t) + torch.exp(t / 50)
# Bin 0 of an unnormalised rfft is the sum of the samples.
clean_dc = torch.fft.rfft(a_clean)[0]
print(clean_dc)

# A local, seeded generator makes the noise (and the printed DC values) reproducible
# without touching the global RNG state used by other cells.
noise = torch.randn(N, generator=torch.Generator().manual_seed(0))
a = a_clean + noise
a_fft = torch.fft.rfft(a)
noise_dc = torch.fft.rfft(noise)[0]
print(a_fft[0])
print(noise_dc)
# The DFT is linear, so the noisy DC bin must equal clean DC + noise DC.
assert (a_fft[0] - (clean_dc + noise_dc)).abs() < 1e-3 * clean_dc.abs()

fig, ax = plt.subplots()
ax.plot(a_clean, label="clean")
ax.plot(a, label="clean + noise")
ax.set(xlabel="sample index", ylabel="value", title="Signal before and after additive Gaussian noise")
ax.legend();

In [None]:
from torch.utils.data import DataLoader
from src.datamodule.data_loader import Dataset_Custom, Dataset_ETT_hour

import os

# The data location can be overridden through the environment instead of being tied to one
# machine; the default is the path this notebook was originally run against.
ETT_ROOT = os.environ.get("ETT_ROOT", "/home/user/data/THU-timeseries/ETT-small/")
# Alternative dataset: root "/home/user/data/FrequencyDiffusion/dataset/", file "MFRED.csv".

ett_test = Dataset_ETT_hour(
    None,  # NOTE(review): role of this positional argument is defined in src.datamodule.data_loader
    root_path=ETT_ROOT,
    data_path="ETTh1.csv",
    flag="test",
    size=[96, 48, 96],  # presumably [seq_len, label_len, pred_len] — confirm in data_loader
    freq="h",
    scale=False,
)
dl = DataLoader(ett_test, batch_size=128, shuffle=False, drop_last=True)
# Inspect only the first batch. next(iter(...)) raises if there is no full batch,
# instead of silently printing nothing.
seq_x, seq_y, seq_x_mark, seq_y_mark = next(iter(dl))
print(seq_x.shape)
print(seq_x[0, :10])
print(seq_y[0, -96:])
print(seq_x_mark[0, :10])

In [81]:
import torch

seq_len = 8
kernel_size = 3
# (batch=1, channel=1, time=seq_len), the layout AvgPool1d and interpolate expect.
a = torch.randn(seq_len).view(1, 1, seq_len)
# Stride-1 moving average: the output length is seq_len - kernel_size + 1.
avg = torch.nn.AvgPool1d(kernel_size, stride=1)
a_avg = avg(a)
print(a_avg)
# Stretch the smoothed series back to seq_len by repeating nearest samples.
a_interp = torch.nn.functional.interpolate(a_avg, size=seq_len, mode="nearest-exact")
a_interp

tensor([[[ 0.1387,  0.1718, -0.1834,  1.0032,  0.4488,  1.3894]]])


tensor([[[ 0.1387,  0.1718,  0.1718, -0.1834,  1.0032,  0.4488,  0.4488,
           1.3894]]])

In [83]:
n_windows = seq_len - kernel_size + 1
# Column j holds the weights of the j-th stride-1 averaging window:
# 1/kernel_size on rows j .. j + kernel_size - 1, zero elsewhere.
conv = torch.zeros(seq_len, n_windows)
for col in range(n_windows):
    conv[col:col + kernel_size, col] = 1 / kernel_size
# Add a leading batch axis so interpolate resizes the window axis (the last one).
conv = conv.unsqueeze(0)
conv_interp = torch.nn.functional.interpolate(
    conv, size=seq_len, mode='nearest-exact'
).squeeze().T  # (seq_len, seq_len)

# Equivalent einsum form: torch.einsum('tl,btc->blc', conv_interp, a.permute(0, 2, 1))
# Matrix form of the pool-then-interpolate pipeline, applied to `a` with time moved to dim 1.
conv_interp @ a.permute(0, 2, 1)

tensor([[[ 0.1387],
         [ 0.1718],
         [ 0.1718],
         [-0.1834],
         [ 1.0032],
         [ 0.4488],
         [ 0.4488],
         [ 1.3894]]])

In [85]:
# Identity minus the Gram matrix C @ C^T of the pool-then-interpolate operator C.
torch.eye(seq_len) - torch.matmul(conv_interp, conv_interp.transpose(0, 1))

tensor([[ 0.6667, -0.2222, -0.2222, -0.1111,  0.0000,  0.0000,  0.0000,  0.0000],
        [-0.2222,  0.6667, -0.3333, -0.2222, -0.1111,  0.0000,  0.0000,  0.0000],
        [-0.2222, -0.3333,  0.6667, -0.2222, -0.1111,  0.0000,  0.0000,  0.0000],
        [-0.1111, -0.2222, -0.2222,  0.6667, -0.2222, -0.1111, -0.1111,  0.0000],
        [ 0.0000, -0.1111, -0.1111, -0.2222,  0.6667, -0.2222, -0.2222, -0.1111],
        [ 0.0000,  0.0000,  0.0000, -0.1111, -0.2222,  0.6667, -0.3333, -0.2222],
        [ 0.0000,  0.0000,  0.0000, -0.1111, -0.2222, -0.3333,  0.6667, -0.2222],
        [ 0.0000,  0.0000,  0.0000,  0.0000, -0.1111, -0.2222, -0.2222,  0.6667]])

In [37]:
import torch
a = torch.randn(2, 96, 1)
# Per-series mean over the time axis, keeping a singleton dim: (2, 1, 1).
a_mean = torch.mean(a, dim=1, keepdim=True)
print(a_mean.shape)
# Three noise draws stacked on a new leading axis: (3, 2, 1, 1).
noise = torch.randn(3, *a_mean.shape)
print(noise.shape)
# Broadcasting adds the same per-series mean to every draw.
a_mean_noise = noise + a_mean
print(a_mean_noise)

torch.Size([2, 1, 1])
torch.Size([3, 2, 1, 1])
tensor([[[[ 0.0513]],

         [[ 1.0238]]],


        [[[-0.4672]],

         [[-0.8953]]],


        [[[-0.0870]],

         [[ 0.3308]]]])


In [18]:
import torch
# One random row per batch element, repeated as a broadcast view along a new length-4 axis.
a = torch.randn(2, 1, 2).expand(-1, 4, -1)  # (2, 4, 2)
# Independent noise for 2 draws: (2, 2, 4, 2). Adding broadcasts `a` across the draw axis.
b = torch.randn(2, *a.shape)
a = a + b
print(a)
# Merge the draw and batch axes: (2 * 2, 4, 2).
a.flatten(start_dim=0, end_dim=1)

tensor([[[[ 0.6473,  1.6876],
          [ 0.7649, -0.7153],
          [ 0.6628, -2.2105],
          [-1.0683, -1.4461]],

         [[-0.8408,  3.9166],
          [ 0.7689,  3.3959],
          [ 1.4035,  2.3657],
          [-1.0246,  3.7916]]],


        [[[-0.7536, -1.3219],
          [ 2.1606, -0.1201],
          [ 2.0112, -0.4738],
          [-0.1889, -1.1185]],

         [[ 1.5766,  2.7478],
          [ 0.0160,  3.2101],
          [ 1.1320,  2.0315],
          [-1.3659,  3.1250]]]])


tensor([[[ 0.6473,  1.6876],
         [ 0.7649, -0.7153],
         [ 0.6628, -2.2105],
         [-1.0683, -1.4461]],

        [[-0.8408,  3.9166],
         [ 0.7689,  3.3959],
         [ 1.4035,  2.3657],
         [-1.0246,  3.7916]],

        [[-0.7536, -1.3219],
         [ 2.1606, -0.1201],
         [ 2.0112, -0.4738],
         [-0.1889, -1.1185]],

        [[ 1.5766,  2.7478],
         [ 0.0160,  3.2101],
         [ 1.1320,  2.0315],
         [-1.3659,  3.1250]]])

In [17]:
# Merge the first two axes with reshape and confirm it matches flatten(end_dim=1).
# The target shape is derived from `a` instead of hard-coding (4, 4, 2).
a_merged = a.reshape(a.shape[0] * a.shape[1], *a.shape[2:])
assert torch.equal(a_merged, a.flatten(end_dim=1))
a_merged

tensor([[[ 0.8440,  0.3983],
         [ 1.8079,  3.6293],
         [ 2.4288,  1.1355],
         [ 1.2998,  3.5468]],

        [[ 1.4144, -1.9098],
         [ 2.0635, -1.3511],
         [ 2.7374, -0.7609],
         [ 2.8206, -0.6089]],

        [[-0.0300,  2.0117],
         [ 2.0565,  1.9444],
         [ 2.4058,  4.1978],
         [ 1.8671,  0.9327]],

        [[ 3.4994, -0.6885],
         [ 1.2738, -1.0749],
         [ 1.1243, -1.8621],
         [ 0.6856, -1.0254]]])