In [4]:
%load_ext autoreload
%autoreload 2

### IMPORTS
import torch
import IPython.display

from music2latent import EncoderDecoder
from torch.utils.data import DataLoader

from src.processing.dataloaders import CachedSimpleDataset
from src.processing.load_music2latent_model import load_music2latent_model
from src.networks.utils import sample_noise

import src.handlers.utils as handlers_utils

# Run on the first CUDA device when one is available, otherwise on the CPU.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Seed torch's global RNG so that noise drawn through it is repeatable across
# "Restart Kernel -> Run All". Change SEED to explore other samples.
SEED = 0
torch.manual_seed(SEED)

# This notebook only runs inference, so autograd is switched off globally.
# The trailing ';' keeps the returned context-manager repr out of the cell output.
torch.set_grad_enabled(False);

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


<torch.autograd.grad_mode.set_grad_enabled at 0x7fc66b272f50>

In [5]:
## CODEC AND VALIDATION DATA

# --- Settings -------------------------------------------------------------
DATASET = "slakh"
# NOTE(review): absolute, machine-specific location; adjust when running elsewhere.
DATASET_PATH = f"/data/alfred/datasets/{DATASET}/music2latent/validation"
BATCH_SIZE = 16

# Both values are forwarded unchanged to CachedSimpleDataset below.
valid_max_samples = 10000
valid_recache_every = 10000

# --- Codec ----------------------------------------------------------------
# music2latent encoder/decoder; later cells call emb_model.decode(...) to turn
# latents back into audio for playback.
emb_model = EncoderDecoder(device=DEVICE)

# --- Dataset and loader ---------------------------------------------------
# Each item exposes the "waveform", "z" and "metadata" entries requested here.
test_dataset = CachedSimpleDataset(
    DATASET_PATH,
    keys=["waveform", "z", "metadata"],
    max_samples=valid_max_samples,
    recache_every=valid_recache_every,
)

# Fixed order, loaded in the main process, incomplete final batch dropped.
dataloader_test = DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    drop_last=True,
)


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [00:21<00:00, 464.02it/s]


# HDAE baseline

In [9]:
# Leading positional arguments of load_music2latent_model for the HDAE baseline.
# Presumably (model kind, config name, checkpoint index) -- TODO confirm against
# src.processing.load_music2latent_model.
hdae_model_spec = ("hdae", "music2latent_hdae_01", 1500)

# The remaining arguments are the shared dataset name, validation loader and device.
hierarchical_handler = load_music2latent_model(*hdae_model_spec, DATASET, dataloader_test, DEVICE)

Loading from diffae/src/saved_models/slakh:music2latent_hdae:config_01
--------------- Model Infos ----------------
Diffusion Unet: 21.693184 M Params
Semantic Encoder: 4.11712 M Params


In [12]:
# SLAKH
def _padded_latent_pair(sample_index):
    """Build the model input for one validation item from its cached latent.

    Reads ``test_dataset[sample_index]["z"]``, adds a leading axis, moves it to
    ``DEVICE`` and passes it through ``handlers_utils.zero_pad`` together with
    ``hierarchical_handler.length``. The result is sliced with ``[:2]``, cast
    to float32 and repeated with ``repeat(2, 1, 1)``.
    """
    latent = torch.from_numpy(test_dataset[sample_index]["z"])
    latent = latent.unsqueeze(0).to(DEVICE)
    padded = handlers_utils.zero_pad(latent, hierarchical_handler.length, DEVICE)
    return padded[:2].float().repeat(2, 1, 1)


# The two validation items whose semantic codes get mixed further down.
encoded_audio_01 = _padded_latent_pair(2)
encoded_audio_02 = _padded_latent_pair(50)

# Semantic codes; the playback cell indexes these at positions 1, 2 and 3.
z_sem_01 = hierarchical_handler.semantic_encoder(encoded_audio_01)
z_sem_02 = hierarchical_handler.semantic_encoder(encoded_audio_02)

# Plain encode/decode round trips, kept as listening references.
# Only the first returned value is used.
reconstruction_01, _ = hierarchical_handler.encode_decode(encoded_audio_01)
reconstruction_02, _ = hierarchical_handler.encode_decode(encoded_audio_02)

In [13]:
# Decode and play the plain reconstructions first, as a reference for the mixes.
for source_id, reconstruction in ((1, reconstruction_01), (2, reconstruction_02)):
    print(f"Audio {source_id} - Reconstruction:")
    print("=================================")
    audio = emb_model.decode(reconstruction[0]).squeeze()
    IPython.display.display(IPython.display.Audio(audio.cpu().detach().numpy(), rate=hierarchical_handler.sample_rate))


# MANIPULATION
# A single noise tensor is shared by every mix, so the mixes differ only in
# their semantic codes.
# NOTE(review): the arguments look like (batch=2, channels=64, length=32) -- confirm
# against src.networks.utils.sample_noise.
x_T = sample_noise("audio", 2, 64, 32, DEVICE)

# Semantic codes of each source, keyed by the source number used in the titles.
z_sem_sources = {1: z_sem_01, 2: z_sem_02}

# Each triple names the source (1 or 2) that supplies hierarchy entries 1, 2 and 3.
mixes = [
    (1, 1, 2),
    (1, 2, 1),
    (2, 1, 1),
    (2, 2, 1),
    (2, 1, 2),
    (1, 2, 2),
]

for mix in mixes:
    print("MIX {} - {} - {} :".format(*mix))
    print("=================================")
    # Entry 0 of the hierarchy is always passed as None; entries 1..3 come
    # from the source selected for that position.
    z_sem = [None] + [z_sem_sources[source][level] for level, source in enumerate(mix, start=1)]
    inference, _ = hierarchical_handler.sampler.sample(hierarchical_handler.diffusion_unet, x_T, z_hierarchical=z_sem)
    audio = emb_model.decode(inference[0]).squeeze()
    # Hand IPython a CPU NumPy array, exactly as for the reconstructions above:
    # the decoded tensor may live on the GPU, where NumPy conversion fails.
    IPython.display.display(IPython.display.Audio(audio.cpu().detach().numpy(), rate=hierarchical_handler.sample_rate))



Audio 1 - Reconstruction:


Audio 2 - Reconstruction:


MIX 1 - 1 - 2 :


MIX 1 - 2 - 1 :


MIX 2 - 1 - 1 :


MIX 2 - 2 - 1 :


MIX 2 - 1 - 2 :


MIX 1 - 2 - 2 :


# Multi-scale spectral hierarchical diffusion auto-encoder

In [60]:
# Leading positional arguments of load_music2latent_model for the chroma/mel model.
# Presumably (model kind, config name, checkpoint index) -- TODO confirm against
# src.processing.load_music2latent_model.
chroma_model_spec = ("chroma", "music2latent_mel_hdae_02", 400)

# The remaining arguments are the shared dataset name, validation loader and device.
chroma_handler = load_music2latent_model(*chroma_model_spec, DATASET, dataloader_test, DEVICE)

STFT kernels created, time used = 0.1910 seconds
STFT filter created, time used = 0.0067 seconds
Mel filter created, time used = 0.0067 seconds
STFT kernels created, time used = 0.0337 seconds
STFT filter created, time used = 0.0014 seconds
Mel filter created, time used = 0.0014 seconds
STFT kernels created, time used = 0.0094 seconds
STFT filter created, time used = 0.0018 seconds
Mel filter created, time used = 0.0018 seconds
Loading from diffae/src/saved_models/slakh:music2latent:chroma_dae:mel:config02
--------------- Model Infos ----------------
Diffusion Unet: 20.134656 M Params
Semantic Encoder: 3.127884 M Params


In [62]:
def _waveform_pair(sample_index):
    """Build the chroma-model input for one validation item from its raw waveform.

    Reads ``test_dataset[sample_index]["waveform"]``, keeps the first
    ``chroma_handler.audio_length`` entries of the last axis, converts the
    array to a tensor on ``DEVICE``, repeats it with ``repeat(2, 1, 1)`` and
    casts it to float32.
    """
    waveform = test_dataset[sample_index]["waveform"]
    clipped = waveform[..., :chroma_handler.audio_length]
    return torch.from_numpy(clipped).to(DEVICE).repeat(2, 1, 1).float()


# The two validation items whose semantic codes get mixed further down.
audio_01 = _waveform_pair(10)
audio_02 = _waveform_pair(44)

# Semantic codes; the playback cell indexes these at positions 1, 2 and 3.
z_sem_01 = chroma_handler.semantic_encoder(audio_01)
z_sem_02 = chroma_handler.semantic_encoder(audio_02)

# Plain encode/decode round trips, kept as listening references.
# Only the first returned value is used.
reconstruction_01, _ = chroma_handler.encode_decode(audio_01)
reconstruction_02, _ = chroma_handler.encode_decode(audio_02)

In [64]:
# Decode and play the plain reconstructions first, as a reference for the mixes.
for source_id, reconstruction in ((1, reconstruction_01), (2, reconstruction_02)):
    print(f"Audio {source_id} - Reconstruction:")
    print("=================================")
    audio = emb_model.decode(reconstruction[0]).squeeze()
    IPython.display.display(IPython.display.Audio(audio.cpu().detach().numpy(), rate=chroma_handler.sample_rate))


# MANIPULATION
# A single noise tensor is shared by every mix, so the mixes differ only in
# their semantic codes.
# NOTE(review): the arguments look like (batch=2, channels=64, length=32) -- confirm
# against src.networks.utils.sample_noise.
x_T = sample_noise("audio", 2, 64, 32, DEVICE)

# Semantic codes of each source, keyed by the source number used in the titles.
z_sem_sources = {1: z_sem_01, 2: z_sem_02}

# Each triple names the source (1 or 2) that supplies hierarchy entries 1, 2 and 3.
mixes = [
    (1, 1, 2),
    (1, 2, 1),
    (2, 1, 1),
    (2, 2, 1),
    (2, 1, 2),
    (1, 2, 2),
]

for mix in mixes:
    print("MIX {} - {} - {} :".format(*mix))
    print("=================================")
    # Entry 0 of the hierarchy is always passed as None; entries 1..3 come
    # from the source selected for that position.
    z_sem = [None] + [z_sem_sources[source][level] for level, source in enumerate(mix, start=1)]
    inference, _ = chroma_handler.sampler.sample(chroma_handler.diffusion_unet, x_T, z_hierarchical=z_sem)
    audio = emb_model.decode(inference[0]).squeeze()
    # Hand IPython a CPU NumPy array, exactly as for the reconstructions above:
    # the decoded tensor may live on the GPU, where NumPy conversion fails.
    IPython.display.display(IPython.display.Audio(audio.cpu().detach().numpy(), rate=chroma_handler.sample_rate))

Audio 1 - Reconstruction:


Audio 2 - Reconstruction:


MIX 1 - 1 - 2 :


MIX 1 - 2 - 1 :


MIX 2 - 1 - 1 :


MIX 2 - 2 - 1 :


MIX 2 - 1 - 2 :


MIX 1 - 2 - 2 :
