In [None]:
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import h5py
import os

from scipy import signal
from scipy import io

In [3]:
import models.models_mae as models_mae

In [5]:
%load_ext autoreload
%autoreload 2

In [None]:
def read_all_seqs(dir='../data', files=None):
    """Load every chirp sequence referenced by one or more HDF5 files.

    Each file must contain a ``chirp_sequence_array`` dataset whose elements
    are HDF5 object references; every reference is resolved to the dataset it
    points at, and that dataset is read fully into memory.

    Args:
        dir: Directory scanned when ``files`` is None. Every entry in it is
            opened with ``h5py.File``, so it should contain only HDF5 files.
            (The name shadows the ``dir`` builtin; kept for existing callers.)
        files: Optional explicit iterable of file paths. When given, ``dir``
            is ignored.

    Returns:
        A list with one array per referenced dataset, ordered by file and then
        by reference order within each file.
    """
    if files is None:
        # os.listdir order is arbitrary and platform dependent; sort it so that
        # integer indices into the result (e.g. ``seq_idx``) are reproducible.
        files = [os.path.join(dir, name) for name in sorted(os.listdir(dir))]

    seqs = []
    for filename in files:
        with h5py.File(filename, 'r') as f:
            # Resolve each object reference to the referenced dataset's name
            # and read the whole dataset while the file is still open.
            seqs.extend(
                f[h5py.h5r.get_name(ref, f.id)][:]
                for ref in f['chirp_sequence_array']
            )
    return seqs

In [None]:
def get_sample(use_wav_example=False, seq_idx=200, seqs=None, target_len=320299):
    """Return a single 1-D test signal for the model.

    Args:
        use_wav_example: If True, return the samples of the WAV file at
            ``./audioset/Y__p-iA312kg.wav`` unchanged (no resampling,
            windowing or padding).
        seq_idx: Index of the chirp sequence to use when ``use_wav_example``
            is False.
        seqs: Optional pre-loaded list of sequences, as returned by
            ``read_all_seqs``. Pass it to avoid re-reading every HDF5 file on
            each call; when None, ``read_all_seqs()`` is called.
        target_len: Length the processed signal is zero-padded to.

    Returns:
        For the WAV example, the data array from ``wavfile.read``. Otherwise a
        1-D float array of length ``target_len``: row 0 of the selected stored
        sequence (presumably its first channel — TODO confirm), resampled by a
        factor of ``110 // 15``, Hamming-windowed, then zero-padded at the end.

    Raises:
        ValueError: if the resampled signal is longer than ``target_len``.
    """
    if use_wav_example:
        # Import the submodule explicitly instead of relying on `scipy.io`
        # exposing `wavfile` as an attribute.
        from scipy.io import wavfile
        _, Y = wavfile.read('./audioset/Y__p-iA312kg.wav')
        return Y

    if seqs is None:
        seqs = read_all_seqs()
    # Transpose so rows index samples and columns index the stored rows
    # (assumed channels — TODO confirm); only column 0 is used below.
    X = seqs[seq_idx].T[:, 0:4]

    upsample_rate = 110 // 15  # match upper range of bat hearing to approx. upper range of human hearing

    x = signal.resample(X[:, 0], X.shape[0] * upsample_rate)
    x = x * signal.windows.hamming(len(x))

    n_pad = target_len - len(x)
    if n_pad < 0:
        raise ValueError(
            "resampled signal has {} samples, longer than target_len={}".format(len(x), target_len)
        )
    x = np.hstack((x, np.zeros(n_pad)))

    return x

In [None]:
def mel_spectrogram(x, model):
    """Compute a normalized mel spectrogram using ``model``'s own front end.

    This is a module-level function, so it takes no ``self``. Every caller in
    this notebook invokes it as ``mel_spectrogram(x, model)``.

    Args:
        x: 3-D tensor (B, C, T), e.g. (batch, channel, samples). The leading two
            dims are flattened before ``model.mel`` and restored afterwards.
        model: Object exposing ``mel``, a callable mapping an (N, T) tensor to
            an (N, F, frames) tensor, and 1-D ``frame_mean`` / ``frame_std``
            tensors of length F used for per-bin standardization.

    Returns:
        Tensor of shape (B, C, F, frames).
    """
    batch, channels, n_samples = x.shape
    spec = model.mel(x.reshape(-1, n_samples))
    # Broadcast the per-bin statistics over the flattened batch and time axes.
    spec = (spec - model.frame_mean[None, :, None]) / model.frame_std[None, :, None]
    return spec.reshape(batch, channels, spec.shape[1], spec.shape[2])

In [None]:
def test_model_forward(x, model, device, mask_ratio=0.75):
    """Run one masked forward pass of ``model`` on a single 1-D signal.

    Args:
        x: 1-D numpy array (e.g. from ``get_sample``). It is reshaped to
            (1, 1, T) and cast to float32.
        model: Model exposing ``mel``/``frame_mean``/``frame_std`` (used via
            ``mel_spectrogram``), ``forward(x, mask_ratio)`` returning
            ``(loss, pred, mask)``, and ``unpatchify``.
        device: torch device; both the model and the input are moved to it.
        mask_ratio: Passed through to ``model.forward`` as its mask ratio.

    Returns:
        Tuple ``(mel_x_cpu, x_prime, loss, mask)``:
          - ``mel_x_cpu``: the normalized mel spectrogram of the first batch
            item / first channel, as a numpy array on the CPU;
          - ``x_prime``: ``model.unpatchify`` applied to ``pred`` with a
            trailing axis added;
          - ``loss`` and ``mask``: as returned by ``model.forward``.
    """
    model = model.to(device)

    # (batch=1, channel=1, samples) on the same device as the model. Using
    # `device` (not `.cuda()`) keeps this working on CPU-only machines and
    # on a non-default GPU.
    x = torch.from_numpy(x.reshape(1, 1, -1).astype('float32')).to(device)

    mel_x = mel_spectrogram(x, model)
    mel_x_cpu = mel_x.detach().cpu()[0][0].numpy()

    loss, pred, mask = model.forward(x, mask_ratio)
    # NOTE(review): assumes pred is 3-D; np.newaxis appends a trailing axis
    # for unpatchify — confirm against models_mae.
    x_prime = model.unpatchify(pred[:, :, :, np.newaxis])

    return mel_x_cpu, x_prime, loss, mask


In [None]:
def plot_specgram(Sx):
    """Show a 2-D array (e.g. a spectrogram) as a wide 15x4 image, rows reversed.

    ``imshow`` draws row 0 at the top by default. Flipping along axis 0 first
    puts row 0 at the bottom of the figure instead.

    Args:
        Sx: Array-like accepted by ``np.flip`` and ``plt.imshow``.
    """
    plt.figure(figsize=(15, 4))
    rows_reversed = np.flip(Sx, 0)
    plt.imshow(rows_reversed, aspect='auto')
    plt.show()

In [None]:
def test_imputation(x, model, device, patch_mask):
    """Encode ``x`` with a caller-chosen patch mask, then decode all patches.

    This replays an MAE-style encoder by hand: patch-embed the spectrogram,
    keep only the patches selected by ``patch_mask``, prepend the cls token,
    and run the Transformer blocks. It then calls ``model.forward_decoder`` so
    the masked patches are reconstructed (imputed).

    Args:
        x: Waveform tensor of shape (B, C, T), as accepted by
            ``mel_spectrogram``.
        model: MAE-style model exposing ``patch_embed`` (with ``img_size``),
            ``pos_embed``, ``cls_token``, ``blocks``, ``norm`` and
            ``forward_decoder``, plus the attributes ``mel_spectrogram`` reads.
        device: torch device; the model, ``x`` and ``patch_mask`` are moved here.
        patch_mask: Array-like with one entry per patch (flattened to 1-D).
            Nonzero/True means the patch is visible to the encoder; zero/False
            means it is dropped and must be imputed. The same mask is applied
            to every batch item.

    Returns:
        The output of ``model.forward_decoder`` for the full patch sequence.
    """
    model = model.to(device)
    x = x.to(device)
    patch_mask = torch.as_tensor(patch_mask, device=device).reshape(-1)

    mel_x = mel_spectrogram(x, model)
    # Crop to the spectrogram size the patch embedding was built for.
    mel_x = mel_x[:, :, :model.patch_embed.img_size[0], :model.patch_embed.img_size[1]]
    tokens = model.patch_embed(mel_x)
    # pos_embed slot 0 belongs to the cls token; slots 1.. to the patches.
    tokens = tokens + model.pos_embed[:, 1:, :]

    # Patch tokens are (B, L, D): the embedding dim is the LAST axis.
    batch, _, dim = tokens.shape

    # `!= 0` works for bool and numeric masks alike (`1 - mask` fails on bool).
    keep = patch_mask != 0
    ids_keep = torch.nonzero(keep, as_tuple=True)[0]
    ids_mask = torch.nonzero(~keep, as_tuple=True)[0]
    # Kept patches first, masked ones after. The decoder needs the INVERSE of
    # this permutation to put tokens back in place (reference-MAE convention,
    # ids_restore = argsort(ids_shuffle) with shape (B, L) — TODO confirm
    # against models_mae.forward_decoder).
    ids_shuffle = torch.cat((ids_keep, ids_mask))
    ids_restore = torch.argsort(ids_shuffle).unsqueeze(0).expand(batch, -1)

    # Select the visible patch tokens for every batch item.
    gather_idx = ids_keep.view(1, -1, 1).expand(batch, -1, dim)
    latent = torch.gather(tokens, dim=1, index=gather_idx)

    cls_token = model.cls_token + model.pos_embed[:, :1, :]
    cls_tokens = cls_token.expand(batch, -1, -1)
    latent = torch.cat((cls_tokens, latent), dim=1)

    # Apply the Transformer encoder to the visible tokens (not the raw waveform).
    for blk in model.blocks:
        latent = blk(latent)
    latent = model.norm(latent)

    pred = model.forward_decoder(latent, ids_restore)
    return pred

In [None]:
## Setup pytorch

# Index of the GPU to use; set to '-1' to force the CPU.
GPU_NUM = '0'

device = torch.device(('cuda:' + GPU_NUM) if (torch.cuda.is_available() and GPU_NUM != '-1') else 'cpu')
print("Using device: {}".format(device))
# Compare the device *type*: a device such as cuda:0 is not equal to the bare
# string 'cuda', so `device == 'cuda'` never matched.
if device.type == 'cuda':
    print("Device index: {}".format(device.index))


In [None]:
# FIXME: this cell was left unfinished (`model = ` alone is a SyntaxError).
# Construct the MAE from `models_mae` here and load its weights. The helper
# functions defined above read these attributes from it: mel, frame_mean,
# frame_std, patch_embed, pos_embed, cls_token, blocks, norm, forward,
# forward_decoder and unpatchify; the cells below also call forward_loss.
model = None

In [None]:
# FIXME: `latent` and `ids_restore` are not defined at notebook scope by any
# cell above — they only appear inside test_imputation — so this cell fails on
# a fresh kernel (Restart & Run All).
pred = model.forward_decoder(latent, ids_restore)

In [None]:
# FIXME: `mel_x` and `mask` are not created at notebook scope by any cell
# above (they exist only as locals inside the helper functions), so this
# cell fails on a fresh kernel. Presumably meant to compute the
# reconstruction loss — TODO confirm forward_loss's expected inputs in models_mae.
model.forward_loss(mel_x, pred, mask)

In [None]:
# Same call as inside test_model_forward: append a trailing axis to `pred`
# and pass it to `model.unpatchify` (presumably reassembling patches into a
# 2-D spectrogram — TODO confirm). Depends on `pred` from the
# forward_decoder cell above.
x_prime = model.unpatchify(pred[:, :, :, np.newaxis])