## Masked Autoencoders: Visualization Demo

This is a visualization demo using our pre-trained MAE models. No GPU is needed.

In [2]:
import sys
import os
import requests
import torchaudio
from torchaudio.compliance import kaldi
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
sys.path.append("../")
import importlib
import models_mae
import librosa
import librosa.display
import importlib

  from .autonotebook import tqdm as notebook_tqdm


### Define utils

In [3]:
MELBINS=128
TARGET_LEN=1024
def prepare_model(chkpt_dir, arch='mae_vit_base_patch16'):
    # build model
    model = getattr(models_mae, arch)(in_chans=1, audio_exp=True,img_size=(1024,128),decoder_mode=0)
    # load model
    checkpoint = torch.load(chkpt_dir, map_location='cpu')
    msg = model.load_state_dict(checkpoint['model'], strict=False)
    print(msg)
    return model
def prepare_model1(chkpt_dir, arch='mae_vit_base_patch16'):
    # build model
    model = getattr(models_mae, arch)(in_chans=1, audio_exp=True,img_size=(1024,128),decoder_mode=1,decoder_depth=16)
    # load model
    checkpoint = torch.load(chkpt_dir, map_location='cpu')
    msg = model.load_state_dict(checkpoint['model'], strict=False)
    print(msg)
    return model
def wav2fbank(filename):

    waveform, sr = torchaudio.load(filename)
    waveform = waveform - waveform.mean()

    # 498 128
    fbank = kaldi.fbank(waveform, htk_compat=True, sample_frequency=sr, use_energy=False, 
                        window_type='hanning', num_mel_bins=MELBINS, dither=0.0, frame_shift=10)
    # AudioSet: 1024 (16K sr)
    n_frames = fbank.shape[0]
    p = TARGET_LEN - n_frames
    # cut and pad
    if p > 0:
        m = torch.nn.ZeroPad2d((0, 0, 0, p))
        fbank = m(fbank)
    elif p < 0:
        fbank = fbank[0:TARGET_LEN, :]
    return fbank
def norm_fbank(fbank):
    norm_mean= -4.2677393
    norm_std= 4.5689974
    fbank = (fbank - norm_mean) / (norm_std * 2)
    return fbank
def display_fbank(bank, minmin=None, maxmax=None):
    #print(bank.shape, bank.min(), bank.max())
    #plt.figure(figsize=(18, 6))
    #plt.figure(figsize=(20, 4))
    plt.imshow(20*bank.T.numpy(), origin='lower', interpolation='nearest', vmax=maxmax, vmin=minmin,  aspect='auto')
    #plt.colorbar()
    #S_db = librosa.amplitude_to_db(np.abs(bank.T.numpy()),ref=np.max)
    #S_db = bank.T.numpy()
    #plt.figure()
    #librosa.display.specshow(10*bank.T.numpy())
    #plt.colorbar()

In [4]:
importlib.reload(models_mae)

import pathlib
temp = pathlib.PosixPath
pathlib.PosixPath = pathlib.WindowsPath

In [5]:
importlib.reload(models_mae)
#chkpt_dir = '/checkpoint/berniehuang/experiments/53417041/checkpoint-80.pth'
#chkpt_dir = '/checkpoint/berniehuang/experiments/55951690/checkpoint-20.pth' #(TF-mask AMAE, 0.7. 0.3)
#chkpt_dir = '/checkpoint/berniehuang/experiments/55986074/checkpoint-20.pth' #(TF-mask AMAE) (0.5, 0.2)
#chkpt_dir = '/checkpoint/berniehuang/experiments/55986075/checkpoint-20.pth' #(TF-mask AMAE) (0.2, 0.1)
#chkpt_dir = '/checkpoint/berniehuang/experiments/54463265/checkpoint-28.pth' # random AMAE
#chkpt_dir = '/checkpoint/berniehuang/experiments/55986072/checkpoint-32.pth' # random AMAE (new)
#chkpt_dir = '/checkpoint/berniehuang/experiments/56067384/checkpoint-28.pth' # random AMAE (new)
#chkpt_dir = '/checkpoint/berniehuang/experiments/56373517/checkpoint-24.pth' # random AMAE, decoder=4, norm_pxl=False
chkpt_dir = r'D:\AudioMAE\pretrained.pth' # random AMAES, decoder=4, norm_pxl=False
model = prepare_model1(chkpt_dir, 'mae_vit_base_patch16')
#model = prepare_model1(chkpt_dir, 'mae_vit_base_patch16')
#model = prepare_model(chkpt_dir, 'amvmae_vit_base_patch16')
print('Model loaded.')

512
PatchEmbed_org(
  (proj): Conv2d(1, 768, kernel_size=(16, 16), stride=(16, 16))
)


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


<All keys matched successfully>
Model loaded.


In [6]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

def run_one_audio(wav_file, model):
    fbank = wav2fbank(wav_file)
    fbank = norm_fbank(fbank)
    x = torch.tensor(fbank)
    x = x.unsqueeze(0)
    x = x.unsqueeze(0)
    mask_ratio = 0.3
    _, y, mask, _ = model(x.float(), mask_ratio=mask_ratio)
    y_unpatch = model.unpatchify(y)
    y_unpatch = torch.einsum('nchw->nhwc', y_unpatch).detach().cpu()
    # visualize the mask
    mask = mask.detach()
    #mask = mask.unsqueeze(-1).repeat(1, 1, model.patch_embed.patch_size[0]**2 *3)  # (N, H*W, p*p*3)
    mask = mask.unsqueeze(-1).repeat(1, 1, model.patch_embed.patch_size[0]**2 *1)
    mask = model.unpatchify(mask)  # 1 is removing, 0 is keeping
    mask = torch.einsum('nchw->nhwc', mask).detach().cpu()
    #print(x.shape, mask.shape)
    # masked image
    x = torch.einsum('nchw->nhwc', x)
    im_masked = x * (1 - mask)
    im_paste = x * (1 - mask) + y * mask
    
    minmin=-5
    maxmax=10
    start=150
    end=800
    
    # make the plt figure larger
    plt.rcParams['figure.figsize'] = [24, 12]
    plt.subplot(3, 1, 1)
    display_fbank(x[0][start:end].squeeze(), minmin=minmin, maxmax=maxmax)
    #plt.show()

    plt.subplot(3, 1, 2)
    display_fbank(im_masked[0][start:end].squeeze(),minmin=minmin, maxmax=maxmax)
    #plt.show()
    
    plt.subplot(3, 1, 3)
    display_fbank((y_unpatch[0][start:end]).squeeze(), minmin=minmin, maxmax=maxmax)
    #plt.show()
    
    print(x.shape)
    print(y.shape)
    print(mask.shape)
    im_paste = x * (1 - mask) + y * mask
    display_fbank(im_paste[0][start:end].squeeze(),minmin=minmin, maxmax=maxmax)
    
    if model.mask_2d:
        fn=os.path.basename(wav_file).replace('.wav',f'_2d_{model.mask_t_prob}_{model.mask_f_prob}.pdf')
    else:
        fn=os.path.basename(wav_file).replace('.wav',f'_{mask_ratio}.pdf')
    fn=os.path.join('/checkpoint/berniehuang/mae/vis',fn)
    plt.savefig(fn)

In [None]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

wav_file=r"D:\AudioMAE\Beijing Police car siren-[AudioTrimmer.com].wav"
fbank = wav2fbank(wav_file)
fbank = norm_fbank(fbank)
x = torch.tensor(fbank)
x = x.unsqueeze(0)
x = x.unsqueeze(0)
mask_ratio = 0.3
_, y, mask, _ = model(x.float(), mask_ratio=mask_ratio)
y_unpatch = model.unpatchify(y)
y_unpatch = torch.einsum('nchw->nhwc', y_unpatch).detach().cpu()
# visualize the mask
mask = mask.detach()
#mask = mask.unsqueeze(-1).repeat(1, 1, model.patch_embed.patch_size[0]**2 *3)  # (N, H*W, p*p*3)
mask = mask.unsqueeze(-1).repeat(1, 1, model.patch_embed.patch_size[0]**2 *1)
mask = model.unpatchify(mask)  # 1 is removing, 0 is keeping
mask = torch.einsum('nchw->nhwc', mask).detach().cpu()
#print(x.shape, mask.shape)
# masked image
x = torch.einsum('nchw->nhwc', x)
im_masked = x * (1 - mask)

minmin=-5
maxmax=10
start=150
end=800

# make the plt figure larger
plt.rcParams['figure.figsize'] = [24, 12]
plt.subplot(3, 1, 1)
display_fbank(x[0][start:end].squeeze(), minmin=minmin, maxmax=maxmax)
#plt.show()

plt.subplot(3, 1, 2)
display_fbank(im_masked[0][start:end].squeeze(),minmin=minmin, maxmax=maxmax)
#plt.show()

plt.subplot(3, 1, 3)
display_fbank((y_unpatch[0][start:end]).squeeze(), minmin=minmin, maxmax=maxmax)
#plt.show()

print(x.shape)
print(y.shape)
print(mask.shape)
im_paste = x * (1 - mask) + y * mask
display_fbank(im_paste[0][start:end].squeeze(),minmin=minmin, maxmax=maxmax)

if model.mask_2d:
    fn=os.path.basename(wav_file).replace('.wav',f'_2d_{model.mask_t_prob}_{model.mask_f_prob}.pdf')
else:
    fn=os.path.basename(wav_file).replace('.wav',f'_{mask_ratio}.pdf')
fn=os.path.join('.',fn)
plt.savefig(fn)

In [16]:
y.shape

torch.Size([1, 512, 256])

In [17]:
x[0][start:end].shape

torch.Size([650, 128, 1])

In [None]:
# wav_file = 

()

In [None]:
wav_file1 = '/large_experiments/cmd/audioset/balance_wav/zye7IPXojSc.wav'
wav_file2='/large_experiments/cmd/audioset/balance_wav/zyqg4pYEioQ.wav'
wav_file3='/large_experiments/cmd/audioset/eval_wav/1W2FOzSXsxs.wav'
wav_file4='/large_experiments/cmd/audioset/eval_wav/1SLrRllxMkU.wav'
wav_file5='/large_experiments/cmd/audioset/eval_wav/1FpNkptebK8.wav'
wav_file6='/large_experiments/cmd/audioset/eval_wav/1IrYZhVhN1s.wav'
wav_file7='/large_experiments/cmd/audioset/eval_wav/0q1wOYCfLlQ.wav'
wav_file8='/large_experiments/cmd/audioset/eval_wav/0qDs_aC0LwI.wav'
wav_file9='/large_experiments/cmd/audioset/eval_wav/0qSK2GuljEc.wav'
wav_file0='/large_experiments/cmd/audioset/eval_wav/0qWRXZkmXF8.wav'
wav_file10='/large_experiments/cmd/audioset/eval_wav/MdYXznF3Eac.wav'
wav_file11='/large_experiments/cmd/audioset/eval_wav/Rr84-EZvO0U.wav'
wav_file12='/large_experiments/cmd/audioset/eval_wav/XHQGUbMSPTM.wav'
wav_file13='/large_experiments/cmd/audioset/eval_wav/bq6C0_tAbJM.wav'
wav_file14='/large_experiments/cmd/audioset/eval_wav/hRbukCd6N68.wav'
wav_file15='/large_experiments/cmd/audioset/eval_wav/HV1J_actdHE.wav'
wav_file16='/large_experiments/cmd/audioset/eval_wav/8UMdVUartLw.wav'
wav_file17='/large_experiments/cmd/audioset/eval_wav/3Mo-YFd31rs.wav'
wav_file18='/large_experiments/cmd/audioset/eval_wav/nT_R3O0OK6U.wav'
wav_file19='/large_experiments/cmd/audioset/eval_wav/bvapjUmC7bY.wav'

In [None]:
wav_file0=r"D:\AudioMAE\Beijing Police car siren-[AudioTrimmer.com].wav"

In [None]:
torch.manual_seed(31)
model.mask_2d=True
model.mask_t_prob=0.1
model.mask_f_prob=0.1
run_one_audio(wav_file0, model)

  x = torch.tensor(fbank)


torch.Size([1, 512, 768])
torch.Size([1, 513, 768])


RuntimeError: The size of tensor a (512) must match the size of tensor b (128) at non-singleton dimension 2