In [1]:
import os
import sys
from functools import partial

import numpy as np
import torch

# Make the project packages (`models`, `data`) importable from this notebook.
# The repo root can be overridden with the MUSICDACVAE_ROOT environment
# variable; the membership guard keeps re-runs of this cell from appending
# the same entry to sys.path again and again.
REPO_ROOT = os.environ.get('MUSICDACVAE_ROOT', '/data2/romit/alan/MusicDacVAE')
if REPO_ROOT not in sys.path:
    sys.path.append(REPO_ROOT)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from models.utils.distributions import DiagonalGaussianDistribution

In [3]:
from models.model.dac_vae import DACVAE

In [4]:
# Build the DAC-VAE. decoder_rates is encoder_rates in reverse order, so the
# decoder undoes the encoder's resampling stages one by one.
model = DACVAE(
    encoder_dim=64,
    encoder_rates=[2, 4, 5, 8],
    latent_dim=80,
    decoder_dim=1536,
    decoder_rates=[8, 5, 4, 2],
    sample_rate=22050,
)
model = model.to("cpu")

def count_params(module):
    """Return the total number of parameter elements in `module`.

    `Module.parameters()` recurses into submodules, so a container's count
    includes every parameter beneath it.
    """
    return sum(param.numel() for param in module.parameters())


def format_with_param_count(base_repr, n_params):
    """Append an ``X.XXXM params.`` suffix to a module's ``extra_repr`` text."""
    return base_repr + f" {n_params/1e6:<.3f}M params."


# Annotate every submodule's repr with its parameter count (in millions) so
# that printing the model doubles as a per-layer size breakdown.
for module in model.modules():
    # Read the class-level extra_repr rather than the (possibly already
    # patched) instance attribute, so re-running this loop on the same model
    # does not stack "M params." suffixes.
    base_repr = type(module).extra_repr(module)
    module.extra_repr = partial(format_with_param_count, base_repr, count_params(module))
print(model)
print("Total # of params: ", count_params(model))


# Smoke-test a forward pass on random audio and report the output shapes.
# NOTE(review): 16000 * 4 is 4 s at 16 kHz, while the model above is built
# with sample_rate=22050; for a shape check only the sample count matters.
length = 16000 * 4

# Dedicated seeded generator: reproducible input without reseeding the
# global RNG that later cells may rely on.
input_generator = torch.Generator().manual_seed(0)
x = torch.randn(2, 1, length, generator=input_generator).to(model.device)

# Make a forward pass. Only shapes are inspected and no backward pass is run,
# so autograd tracking is disabled to avoid building a graph over the whole
# model. (To probe the receptive field via input gradients instead, drop
# no_grad, call x.requires_grad_(True), and backpropagate from `out`.)
with torch.no_grad():
    out = model(x)
print("Input shape:", x.shape)
print("Output shape:", out['audio'].shape)
print("z shape:", out['z'].shape)
print("loss_KLD:", out['loss_KLD'].shape)

DACVAE(
   61.050M params.
  (encoder): Encoder(
     18.866M params.
    (block): Sequential(
       18.866M params.
      (0): Conv1d(1, 64, kernel_size=(7,), stride=(1,), padding=(3,) 0.001M params.)
      (1): EncoderBlock(
         0.133M params.
        (block): Sequential(
           0.133M params.
          (0): ResidualUnit(
             0.033M params.
            (block): Sequential(
               0.033M params.
              (0): Snake1d( 0.000M params.)
              (1): Conv1d(64, 64, kernel_size=(7,), stride=(1,), padding=(3,) 0.029M params.)
              (2): Snake1d( 0.000M params.)
              (3): Conv1d(64, 64, kernel_size=(1,), stride=(1,) 0.004M params.)
            )
          )
          (1): ResidualUnit(
             0.033M params.
            (block): Sequential(
               0.033M params.
              (0): Snake1d( 0.000M params.)
              (1): Conv1d(64, 64, kernel_size=(7,), stride=(1,), padding=(9,), dilation=(3,) 0.029M params.)
            

In [5]:
from data.dataset_slakh import dataset_slakh2100

In [6]:
# Slakh2100 training segments described by the metadata CSV. With
# sample_rate=22050, segment_length = 3 * 22050 yields 3-second clips
# (66150 samples), matching "segment_3" in the metadata filename.
dataset = dataset_slakh2100(
    meta_data_path=(
        '/data2/romit/alan/MusicDacVAE/data/metadata/'
        'slakh_metadata_train_segment_3_shift_1_samples_500000_rms_0.1_0.16_snr_10_30.csv'
    ),
    sample_rate=22050,
    segment_length=3 * 22050,
)

In [9]:
# Inspect one item; the recorded output was torch.Size([66150]).
first_item = dataset[0]
first_item.shape

torch.Size([66150])

In [10]:
# Dataset size; the recorded output (500000) matches "samples_500000" in the
# metadata filename.
n_segments = len(dataset)
n_segments

500000

In [11]:
import soundfile as sf

In [12]:
# Write the first dataset segment to disk for a quick listening check.
# NOTE(review): soundfile's default WAV subtype is PCM_16, which clips
# samples outside [-1, 1]; pass subtype='FLOAT' if headroom is a concern.
sf.write(
    file='sample.wav',
    data=dataset[0],
    samplerate=22050,
)