In [70]:
import torchaudio
import torch
import torch.nn as nn
from util.patch_embed import PatchEmbed_org
from model import Block

In [71]:
a = torch.randn(2, 10*16000)

In [72]:
sample_rate = 16000
win_length = int(sample_rate * 0.025)  # 25ms
hop_length = int(sample_rate * 0.01)  # 10ms
transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=sample_rate,
    win_length=win_length,
    hop_length=hop_length,
    n_fft=win_length,
    n_mels=128,  # 你可以根据需要调整mel频带的数量
    window_fn=torch.hamming_window
)



In [73]:
b = transform(a).unsqueeze(1)

In [74]:
b.shape

torch.Size([2, 1, 128, 1001])

In [75]:
patch_to_emb = PatchEmbed_org(224, 16, 1, 768)
num_patches = (b.shape[2] // 16) * (b.shape[3] // 16)
pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, 768), requires_grad=False) 

In [76]:
embs = patch_to_emb(b)
input_tokens = embs + pos_embed[:, 1:, :]
cls_token = nn.Parameter(torch.zeros(1, 1, 768)) + pos_embed[:, 1:, :]

In [77]:
cls_tokens = cls_token.expand(input_tokens.shape[0], -1, -1)
x = torch.cat((cls_tokens, input_tokens), dim=1)

In [78]:
blk = Block(768, 16, 4, qkv_bias=True, norm_layer=nn.LayerNorm)

In [79]:
encoder_result = blk(x)

In [80]:
embs.shape

torch.Size([2, 496, 768])

In [81]:
dummy = torch.rand(2, 5, 5)
noise = torch.rand(2, 5)
ids_shuffle = torch.argsort(noise, dim=1)
ids_restore = torch.argsort(ids_shuffle, dim=1)

In [82]:
print(dummy)
print(noise)
print(ids_shuffle)
print(ids_restore)

tensor([[[0.6453, 0.2386, 0.8257, 0.8505, 0.9747],
         [0.7550, 0.5414, 0.8062, 0.6918, 0.1312],
         [0.2928, 0.6910, 0.6827, 0.0190, 0.1402],
         [0.9646, 0.6618, 0.2535, 0.8901, 0.2351],
         [0.4347, 0.5404, 0.2957, 0.0304, 0.1935]],

        [[0.2689, 0.9827, 0.1391, 0.4593, 0.3023],
         [0.7796, 0.2057, 0.9030, 0.1741, 0.9467],
         [0.2775, 0.2084, 0.3652, 0.3638, 0.2278],
         [0.9628, 0.5409, 0.2111, 0.7085, 0.8050],
         [0.1502, 0.5530, 0.7047, 0.4451, 0.9808]]])
tensor([[0.3775, 0.7035, 0.9517, 0.9874, 0.3406],
        [0.0523, 0.8955, 0.2882, 0.0307, 0.3046]])
tensor([[4, 0, 1, 2, 3],
        [3, 0, 2, 4, 1]])
tensor([[1, 2, 3, 4, 0],
        [1, 4, 2, 0, 3]])


In [83]:
ids_keep = ids_shuffle[:, :2]
dummy_masked = torch.gather(dummy, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, 5))

In [84]:
ids_keep.unsqueeze(-1).repeat(1, 1, 5)

tensor([[[4, 4, 4, 4, 4],
         [0, 0, 0, 0, 0]],

        [[3, 3, 3, 3, 3],
         [0, 0, 0, 0, 0]]])

In [85]:
dummy_masked

tensor([[[0.4347, 0.5404, 0.2957, 0.0304, 0.1935],
         [0.6453, 0.2386, 0.8257, 0.8505, 0.9747]],

        [[0.9628, 0.5409, 0.2111, 0.7085, 0.8050],
         [0.2689, 0.9827, 0.1391, 0.4593, 0.3023]]])

In [None]:
# generate the binary mask: 0 is keep, 1 is remove
mask = torch.ones([2, 5], device=dummy.device)
mask[:, :2] = 0
# unshuffle to get the binary mask
mask = torch.gather(mask, dim=1, index=ids_restore)

In [122]:
mask

tensor([[0., 0., 1., 1., 1.],
        [0., 1., 1., 1., 0.]])

In [94]:
decoder_embed = nn.Linear(5, 5, bias=True)
embs = decoder_embed(dummy_masked)

In [110]:
mask_token = nn.Parameter(torch.zeros(1, 1, 5))
mask_tokens = mask_token.repeat(embs.shape[0], ids_restore.shape[1] + 1 - embs.shape[1], 1)

In [111]:
print(embs.shape)
print(ids_restore.shape)

torch.Size([2, 2, 5])
torch.Size([2, 5])


In [113]:
mask_tokens.shape

torch.Size([2, 4, 5])

In [117]:
x1 = torch.cat([embs[:, :, :], mask_tokens], dim=1)

In [118]:
print(embs)
print(x1)

tensor([[[ 0.1799, -0.6239,  0.4720,  0.2164, -0.1589],
         [ 0.4824, -0.5183,  0.6333,  0.7364, -0.3762]],

        [[ 0.3029, -0.6917,  0.4276,  0.4443, -0.3834],
         [ 0.1936, -0.5829,  0.2732,  0.2881, -0.2181]]],
       grad_fn=<ViewBackward0>)
tensor([[[ 0.1799, -0.6239,  0.4720,  0.2164, -0.1589],
         [ 0.4824, -0.5183,  0.6333,  0.7364, -0.3762],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 0.3029, -0.6917,  0.4276,  0.4443, -0.3834],
         [ 0.1936, -0.5829,  0.2732,  0.2881, -0.2181],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]]],
       grad_fn=<CatBackward0>)


In [119]:
x2 = torch.gather(x1, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))

In [120]:
x2

tensor([[[ 0.4824, -0.5183,  0.6333,  0.7364, -0.3762],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.1799, -0.6239,  0.4720,  0.2164, -0.1589]],

        [[ 0.1936, -0.5829,  0.2732,  0.2881, -0.2181],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.3029, -0.6917,  0.4276,  0.4443, -0.3834],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]]],
       grad_fn=<GatherBackward0>)

In [1]:
ids_restore

NameError: name 'ids_restore' is not defined