In [1]:
import torchaudio
import torch
import torch.nn as nn
from util.patch_embed import PatchEmbed_org
from models_mae import MaskedAutoencoder

In [2]:
# Instantiate the MAE with a 768-dim embedding (matches embed_dim used for the
# padding masks below). NOTE(review): do_mask=False presumably disables random
# patch masking so every patch reaches the encoder — confirm in models_mae.
model = MaskedAutoencoder(embed_dim=768, do_mask=False)

In [3]:
def padding(spec_batch, in_chanel, embed_dim, patch_size=16, smallest_length=1024):
    """Right-pad a batch of spectrograms along the last axis and build a patch padding mask.

    The whole batch is zero-padded on its last (width) axis up to
    ``max(W, smallest_length)`` columns. A patch is reported as padding when every
    value inside it, across all channels, is exactly zero. That is the same
    condition under which a bias-free Conv2d patch embedding outputs zeros, but it
    is evaluated deterministically and without allocating a projection layer.

    Args:
        spec_batch: tensor of shape (N, C, H, W).
        in_chanel: expected channel count C (validated against ``spec_batch``).
        embed_dim: size of the mask's last dimension; the per-patch flag is
            broadcast across it so the mask lines up with (N, L, embed_dim) tokens.
        patch_size: side length of the square, non-overlapping patches.
        smallest_length: minimum output width; narrower batches are padded up to it.

    Returns:
        (padded_specs, padding_masks):
            padded_specs: (N, C, H, max(W, smallest_length)), same dtype/device as input.
            padding_masks: bool tensor (N, L, embed_dim) with
                L = (H // patch_size) * (W_out // patch_size). Patches are ordered
                row-major over the patch grid (as Conv2d(...).flatten(2) orders them).
                True marks an all-zero (padding) patch.

    Raises:
        ValueError: if ``spec_batch`` is not 4-D or its channel count != ``in_chanel``.
    """
    if spec_batch.dim() != 4:
        raise ValueError(
            f"expected a 4-D (N, C, H, W) batch, got shape {tuple(spec_batch.shape)}"
        )
    n, c, h, w = spec_batch.shape
    if c != in_chanel:
        raise ValueError(f"spec_batch has {c} channels, expected in_chanel={in_chanel}")

    # All items of a 4-D tensor share the same width, so one pad call covers the
    # batch. F.pad keeps dtype, device and autograd, and is a no-op-sized pad when
    # the batch is already wide enough (previously that case crashed).
    target_width = max(w, smallest_length)
    padded_specs = nn.functional.pad(spec_batch, (0, target_width - w))

    # Patch grid; like a stride-`patch_size` Conv2d, trailing rows/columns that do
    # not fill a whole patch are ignored.
    grid_h, grid_w = h // patch_size, target_width // patch_size
    patches = padded_specs[:, :, : grid_h * patch_size, : grid_w * patch_size]
    patches = patches.reshape(n, c, grid_h, patch_size, grid_w, patch_size)

    # A patch is padding iff it holds no non-zero value in any channel/pixel.
    has_content = (patches != 0).any(dim=5).any(dim=3).any(dim=1)  # (N, grid_h, grid_w)
    is_padding = ~has_content

    padding_masks = (
        is_padding.flatten(1)               # (N, L), row-major over the grid
        .unsqueeze(-1)
        .expand(-1, -1, embed_dim)          # (N, L, embed_dim)
        .contiguous()                       # materialise so callers may write to it
    )
    return padded_specs, padding_masks

In [4]:
torch.manual_seed(0)  # reproducible demo batch on Restart & Run All
# Two single-channel 128 x 1001 inputs; 1001 < 1024, so padding is exercised.
raw_specs = torch.rand(2, 1, 128, 1001)
# Keep the raw batch intact; `a` / `masks` are the padded batch and mask used below.
a, masks = padding(raw_specs, in_chanel=1, embed_dim=768, patch_size=16, smallest_length=1024)

In [5]:
# Full mask shape, then the mask restricted to its first 512 and 256 channels.
for channel_slice in (slice(None), slice(None, 512), slice(None, 256)):
    print(masks[:, :, channel_slice].shape)

torch.Size([2, 512, 768])
torch.Size([2, 512, 512])
torch.Size([2, 512, 256])


In [6]:
# Forward the padded batch together with its padding mask. The rendered output is
# a tuple of tensors whose meaning is defined in models_mae (not visible here).
# NOTE(review): this cell has no print, so the `torch.Size([...])` line in the
# output presumably comes from a debug print inside the model — consider removing it.
model(a, padding_mask=masks)

torch.Size([2, 512, 256])


(tensor([[1.5013, 1.5960, 1.5284,  ..., 1.4298, 1.4465, 0.0000],
         [1.5653, 1.5508, 1.5583,  ..., 1.5301, 1.4107, 0.0000]],
        grad_fn=<MeanBackward1>),
 tensor([[[-0.1342, -1.7399,  2.5878,  ...,  1.4804,  1.3534,  1.0617],
          [-0.1890, -1.7116,  2.8103,  ...,  1.5090,  1.4721,  1.0208],
          [-0.2631, -1.5977,  3.0126,  ...,  1.5937,  1.0246,  1.2272],
          ...,
          [-0.1900, -2.7658,  2.5324,  ...,  0.8216,  0.5004,  1.2674],
          [-0.0825, -2.7024,  2.5177,  ...,  0.8202,  0.5130,  1.2409],
          [ 0.0000, -0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
 
         [[-0.2108, -1.7987,  2.5642,  ...,  1.4829,  1.3316,  0.8886],
          [-0.1336, -1.7770,  2.9069,  ...,  1.5689,  1.3558,  1.1317],
          [-0.2668, -1.5569,  3.0893,  ...,  1.3738,  1.0580,  1.2533],
          ...,
          [-0.1837, -2.8484,  2.5315,  ...,  0.7940,  0.5916,  1.2866],
          [-0.0241, -2.7930,  2.5626,  ...,  0.8455,  0.5201,  1.2673],
          

In [7]:
def patchify(imgs, patch_size=16):
    """Split a batch of images into flattened, non-overlapping square patches.

    Args:
        imgs: tensor of shape (N, C, H, W); H and W must be multiples of patch_size.
        patch_size: side length of each square patch (default 16, as before).

    Returns:
        Tensor of shape (N, (H // p) * (W // p), p * p * C). Patches are ordered
        row-major over the patch grid, and each patch vector is laid out as
        (row, col, channel), following the 'nchpwq->nhwpqc' permutation.

    Raises:
        ValueError: if H or W is not divisible by patch_size.
    """
    n, c, height, width = imgs.shape
    if height % patch_size or width % patch_size:
        raise ValueError(
            f"image size {height}x{width} is not divisible by patch_size={patch_size}"
        )
    grid_h = height // patch_size
    grid_w = width // patch_size
    x = imgs.reshape(n, c, grid_h, patch_size, grid_w, patch_size)
    # Move the per-patch pixel axes (p, q) and channels (c) to the end.
    x = torch.einsum('nchpwq->nhwpqc', x)
    return x.reshape(n, grid_h * grid_w, patch_size ** 2 * c)

In [8]:
# Patchify a random (2, 1, 128, 1024) batch as a shape sanity check.
demo_specs = torch.rand(2, 1, 128, 1024)
patches = patchify(demo_specs)

In [9]:
# Expect (2, 512, 256): an 8 x 64 grid of 16 x 16 single-channel patches.
patches.size()

torch.Size([2, 512, 256])