In [1]:
%load_ext autoreload
%autoreload 2

import torch

from moment.utils.masking import Masking
from moment.models.layers.revin import RevIN

In [4]:
batch_size = 16
n_patches = 4
d_model = 3

input_mask_patch_view = torch.rand((batch_size, n_patches))
print(input_mask_patch_view.shape)
print(input_mask_patch_view)

expanded_mask = input_mask_patch_view.unsqueeze(-1).repeat(1, 1, d_model)
print(expanded_mask.shape)
print(expanded_mask)

torch.Size([16, 4])
tensor([[0.7891, 0.5462, 0.3527, 0.4715],
        [0.1525, 0.5782, 0.2720, 0.0549],
        [0.2056, 0.4058, 0.1622, 0.1487],
        [0.0691, 0.4874, 0.7084, 0.6822],
        [0.8322, 0.3362, 0.0359, 0.9756],
        [0.6620, 0.0850, 0.5008, 0.2533],
        [0.3823, 0.8053, 0.0732, 0.1288],
        [0.6464, 0.1721, 0.9925, 0.7255],
        [0.3247, 0.0348, 0.9719, 0.6063],
        [0.7039, 0.7127, 0.8540, 0.9093],
        [0.0681, 0.9114, 0.6846, 0.2366],
        [0.3610, 0.8556, 0.5669, 0.2921],
        [0.9284, 0.3298, 0.6816, 0.4342],
        [0.0929, 0.8676, 0.3866, 0.3630],
        [0.8581, 0.6999, 0.6013, 0.2328],
        [0.7122, 0.2464, 0.4994, 0.4607]])
torch.Size([16, 4, 3])
tensor([[[0.7891, 0.7891, 0.7891],
         [0.5462, 0.5462, 0.5462],
         [0.3527, 0.3527, 0.3527],
         [0.4715, 0.4715, 0.4715]],

        [[0.1525, 0.1525, 0.1525],
         [0.5782, 0.5782, 0.5782],
         [0.2720, 0.2720, 0.2720],
         [0.0549, 0.0549, 0.0549]],



In [7]:
batch_size = 5
seq_len = 8
patch_len = 4
n_channels = 2

x_enc = torch.rand((batch_size, n_channels, seq_len))
input_mask = torch.ones((batch_size, seq_len))
mask_obj = Masking(mask_ratio=0.3, patch_len=patch_len, stride=patch_len)
generated_mask = mask_obj.generate_mask(x_enc, input_mask=input_mask)
print(generated_mask.shape)

torch.Size([5, 8])


In [5]:
generated_mask_in_patch_view = Masking.convert_seq_to_patch_view(generated_mask)
print(generated_mask_in_patch_view.shape)

torch.Size([8, 64])


In [6]:
generated_mask_in_seq_view = Masking.convert_patch_to_seq_view(generated_mask_in_patch_view)
print(generated_mask_in_seq_view.shape)

torch.Size([8, 512])


In [7]:
assert torch.all(generated_mask == generated_mask_in_seq_view)

### RevIN

In [8]:
batch_size = 16
seq_len = 512
patch_len = 8
n_channels = 1

input_mask = torch.ones((batch_size, seq_len)) 
x_enc = torch.rand((batch_size, n_channels, seq_len))
print(f"x_enc.shape: {x_enc.shape}")
mask_obj = Masking(mask_ratio=0.3, patch_len=patch_len, stride=patch_len)
generated_mask = mask_obj.generate_mask(x_enc, input_mask=input_mask)
print(f"generated_mask.shape: {generated_mask.shape}")

x_enc.shape: torch.Size([16, 1, 512])
generated_mask.shape: torch.Size([16, 512])


In [23]:
generated_mask

tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0]])

In [24]:
from moment.utils.data import nanvar

In [37]:
revin_obj = RevIN(num_features=n_channels, affine=False)
x_enc_norm = revin_obj(x_enc, mask=generated_mask, mode="norm")
masked_x_enc = torch.where(generated_mask.unsqueeze(1).bool(), x_enc, torch.nan)

# print("Before reversible instance normalization:", x_enc.shape, 
#       '\n', x_enc.squeeze().mean(axis=-1), '\n', x_enc.squeeze().var(axis=-1))
print("(Masked) Before reversible instance normalization:", masked_x_enc.shape, 
      '\n', masked_x_enc.squeeze().nanmean(axis=-1), '\n', nanvar(masked_x_enc.squeeze(), dim=-1))

masked_x_enc_norm = torch.where(generated_mask.unsqueeze(1).bool(), x_enc_norm, torch.nan)
# print("After reversible instance normalization:", x_enc_norm.shape, 
#       '\n', x_enc_norm.squeeze().nanmean(dim=-1), '\n', x_enc_norm.squeeze().var(axis=-1))
print("(Masked) After reversible instance normalization:", masked_x_enc_norm.shape, 
      '\n', masked_x_enc_norm.squeeze().nanmean(axis=-1), '\n', nanvar(masked_x_enc_norm.squeeze(), dim=-1))

x_enc_denorm = revin_obj(x_enc_norm, mode="denorm")
masked_x_enc_denorm = torch.where(generated_mask.unsqueeze(1).bool(), x_enc_denorm, torch.nan)
# print("After reversible instance denormalization:", x_enc_denorm.shape, '\n', 
#       x_enc_denorm.squeeze().mean(axis=-1), '\n', x_enc_denorm.squeeze().var(axis=-1))
print("(Masked) After reversible instance denormalization:", x_enc_denorm.shape, '\n', 
      '\n', masked_x_enc_denorm.squeeze().nanmean(axis=-1), '\n', nanvar(masked_x_enc_denorm.squeeze(), dim=-1))

(Masked) Before reversible instance normalization: torch.Size([16, 1, 512]) 
 tensor([0.5089, 0.5052, 0.5021, 0.5068, 0.4967, 0.5139, 0.4883, 0.5045, 0.5111,
        0.5041, 0.5069, 0.5228, 0.5326, 0.5072, 0.4681, 0.4822]) 
 tensor([0.0818, 0.0828, 0.0767, 0.0804, 0.0785, 0.0832, 0.0773, 0.0774, 0.0782,
        0.0805, 0.0851, 0.0795, 0.0876, 0.0806, 0.0739, 0.0895])
(Masked) After reversible instance normalization: torch.Size([16, 1, 512]) 
 tensor([-5.1657e-08,  8.2122e-08, -1.3245e-07,  0.0000e+00, -8.8745e-08,
        -6.6227e-08,  1.8279e-07,  1.1921e-08, -1.9868e-08, -3.9736e-09,
        -6.6227e-08,  7.1526e-08, -7.5499e-08,  1.3245e-09,  8.7420e-08,
         1.6689e-07]) 
 tensor([0.9999, 0.9999, 0.9999, 0.9999, 0.9999, 0.9999, 0.9999, 0.9999, 0.9999,
        0.9999, 0.9999, 0.9999, 0.9999, 0.9999, 0.9999, 0.9999])
(Masked) After reversible instance denormalization: torch.Size([16, 1, 512]) 
 
 tensor([0.5089, 0.5052, 0.5021, 0.5068, 0.4967, 0.5139, 0.4883, 0.5045, 0.5111,
    

In [7]:
revin_obj = RevIN(num_features=n_channels, affine=True)
x_enc_norm = revin_obj(x_enc, mask=generated_mask, mode="norm")

print("Before reversible instance normalization:", x_enc.shape, x_enc.mean(), x_enc.var())
print("After reversible instance normalization:", x_enc_norm.shape, x_enc_norm.mean(), x_enc_norm.var())

x_enc_denorm = revin_obj(x_enc_norm, mode="denorm")
print("After reversible instance denormalization:", x_enc_denorm.shape, x_enc_denorm.mean(), x_enc_denorm.var())

Before reversible instance normalization: torch.Size([8, 3, 512]) tensor(0.5017) tensor(0.0830)
After reversible instance normalization: torch.Size([8, 3, 512]) tensor(0.0029, grad_fn=<MeanBackward0>) tensor(0.9950, grad_fn=<VarBackward0>)
After reversible instance denormalization: torch.Size([8, 3, 512]) tensor(0.5017, grad_fn=<MeanBackward0>) tensor(0.0830, grad_fn=<VarBackward0>)


In [10]:
Masking.convert_seq_to_patch_view(generated_mask, patch_len=patch_len)

tensor([[1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1,
         1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0,
         0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0],
        [0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1,
         1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1,
         0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0],
        [0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1,
         1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1,
         0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1],
        [1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0,
         1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0,
         0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1,
         1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 

In [11]:
from torch import nn
mask_embedding = nn.Parameter(torch.zeros(768))
nn.init.trunc_normal_(mask_embedding, mean=0.0, std=.02)
value_embedding = nn.Linear(patch_len, 768, bias=True)

n_channels = 3

In [12]:
print(mask_embedding)

Parameter containing:
tensor([-5.6563e-03,  1.1400e-02,  3.8069e-02, -2.7084e-02,  6.1327e-03,
        -3.8223e-03,  9.5821e-03,  9.1625e-03, -7.8339e-03, -2.7275e-02,
        -1.9771e-04,  1.7950e-02,  1.8530e-02,  1.4511e-03,  9.0782e-03,
        -9.0193e-03,  8.3833e-03,  5.5514e-03, -1.7480e-02,  5.3386e-03,
        -1.1893e-02,  3.7317e-03, -1.3266e-02,  8.8647e-03,  3.4053e-02,
         3.5585e-02, -7.9008e-03,  4.0723e-02, -2.1418e-03, -1.1216e-03,
        -1.6082e-02, -3.5233e-02,  8.3909e-03, -1.7249e-02,  5.7902e-03,
         1.4509e-02, -9.2847e-03, -1.3128e-02,  1.4970e-02, -8.8669e-03,
        -9.5377e-03,  1.1817e-02,  2.5525e-02,  5.4501e-03, -4.0365e-03,
         1.9814e-03, -9.6465e-03,  1.7635e-02, -6.6043e-03, -1.3949e-04,
        -5.4365e-03,  5.9173e-03, -2.0668e-02,  2.6240e-03, -6.0574e-03,
        -3.7557e-02, -2.5533e-02, -7.9768e-03, -7.7971e-03, -1.0771e-03,
        -8.2750e-03,  1.0797e-02,  1.7825e-02, -8.5885e-04, -4.8601e-03,
        -1.0844e-03, -3.4335e

In [13]:
print(generated_mask.shape)

torch.Size([8, 512])


In [14]:
x = torch.rand((batch_size, n_channels, 64, patch_len))

In [15]:
mask = Masking.convert_seq_to_patch_view(generated_mask, patch_len=patch_len).unsqueeze(-1)
print(mask.shape)

torch.Size([8, 64, 1])


In [16]:
mask = mask.repeat_interleave(768, dim=-1).unsqueeze(1).repeat(1, n_channels, 1, 1)
print(mask.shape)

torch.Size([8, 3, 64, 768])


In [17]:
# x = x.reshape((x.shape[0] * n_channels, x.shape[2], x.shape[3]))
# mask = mask.reshape((mask.shape[0] * n_channels, mask.shape[2], mask.shape[3]))

In [7]:
batch_size = 5
seq_len = 16
patch_len = 4
n_channels = 1

x_enc = torch.rand((batch_size, n_channels, seq_len))

mask_obj = Masking(mask_ratio=0.3, patch_len=patch_len, stride=patch_len)
generated_mask = mask_obj.generate_mask(x_enc)
generated_mask_1 = mask_obj.generate_mask(x_enc)

In [9]:
masked_x = torch.where(generated_mask.bool(), x_enc, torch.nan)

In [11]:
generated_mask

tensor([[1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0]])

In [10]:
masked_x

tensor([[[0.5564, 0.0124, 0.9068, 0.0675,    nan,    nan,    nan,    nan,
          0.2088, 0.3609, 0.7477, 0.3044,    nan,    nan,    nan,    nan],
         [0.5564, 0.0124, 0.9068, 0.0675, 0.1514, 0.8795, 0.4807, 0.9743,
          0.2088, 0.3609, 0.7477, 0.3044,    nan,    nan,    nan,    nan],
         [   nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
          0.2088, 0.3609, 0.7477, 0.3044, 0.7728, 0.5436, 0.1199, 0.0566],
         [   nan,    nan,    nan,    nan, 0.1514, 0.8795, 0.4807, 0.9743,
             nan,    nan,    nan,    nan, 0.7728, 0.5436, 0.1199, 0.0566],
         [0.5564, 0.0124, 0.9068, 0.0675, 0.1514, 0.8795, 0.4807, 0.9743,
          0.2088, 0.3609, 0.7477, 0.3044,    nan,    nan,    nan,    nan]],

        [[0.2979, 0.9697, 0.1694, 0.4599,    nan,    nan,    nan,    nan,
          0.6839, 0.0991, 0.4961, 0.0106,    nan,    nan,    nan,    nan],
         [0.2979, 0.9697, 0.1694, 0.4599, 0.1915, 0.2373, 0.0655, 0.7039,
          0.6839, 0.0991, 0.49

In [64]:
print(generated_mask, f'\n{generated_mask_1}')

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1]]) 
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]])


In [65]:
generated_mask*generated_mask_1

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]])

In [36]:
print(x_enc.shape, generated_mask.shape)

torch.Size([5, 1, 16]) torch.Size([5, 16])


In [37]:
generated_mask

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
        [1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1]])

In [38]:
torch.masked_select(input=x_enc, mask=generated_mask.unsqueeze(1).repeat((1,n_channels,1)).bool())

tensor([0.0125, 0.2545, 0.2849, 0.0521, 0.5358, 0.6295, 0.6922, 0.1970, 0.4741,
        0.2726, 0.4006, 0.8102, 0.7616, 0.4026, 0.9062, 0.7949, 0.9387, 0.1628,
        0.3855, 0.8334, 0.6171, 0.0557, 0.8788, 0.0384, 0.1543, 0.0154, 0.1122,
        0.1321, 0.6690, 0.6656, 0.5286, 0.6932, 0.4928, 0.8596, 0.3379, 0.2115,
        0.9073, 0.5408, 0.6029, 0.0463, 0.5339, 0.4061, 0.4120, 0.9525, 0.5427,
        0.4730, 0.5600, 0.2020, 0.1492, 0.9209, 0.4773, 0.8560])

In [46]:
masked_tensor = torch.where(generated_mask.unsqueeze(1).repeat((1,n_channels,1)).bool(), x_enc, torch.nan)
print(masked_tensor.shape, f'\n{masked_tensor}')

torch.Size([5, 1, 16]) 
tensor([[[0.0125, 0.2545, 0.2849, 0.0521, 0.5358, 0.6295, 0.6922, 0.1970,
          0.4741, 0.2726, 0.4006, 0.8102, 0.7616, 0.4026, 0.9062, 0.7949]],

        [[0.9387, 0.1628, 0.3855, 0.8334, 0.6171, 0.0557, 0.8788, 0.0384,
          0.1543, 0.0154, 0.1122, 0.1321, 0.6690, 0.6656, 0.5286, 0.6932]],

        [[   nan,    nan,    nan,    nan,    nan,    nan,    nan,    nan,
             nan,    nan,    nan,    nan, 0.4928, 0.8596, 0.3379, 0.2115]],

        [[0.9073, 0.5408, 0.6029, 0.0463,    nan,    nan,    nan,    nan,
          0.5339, 0.4061, 0.4120, 0.9525,    nan,    nan,    nan,    nan]],

        [[0.5427, 0.4730, 0.5600, 0.2020,    nan,    nan,    nan,    nan,
             nan,    nan,    nan,    nan, 0.1492, 0.9209, 0.4773, 0.8560]]])
