# Audio-to-audio generation 

<div style="text-align:center;">
<img src="../images/method.png" alt="Overview of the method" width="800" />
</div>


This notebook runs inference for audio-to-audio generation: it combines the structure of one recording with the timbre of another. We demonstrate it on the demo samples from the [webpage](https://nilsdem.github.io/control-transfer-diffusion/), but you can load your own structure and timbre targets.
Please note that although any structure input can be used, the timbre target must come from the training datasets (or be quite similar to them).


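Make sure to download the pretrained models (**TODO:** the download link is missing) and place them in `./pretrained`; the cells below expect `./pretrained/slakh/checkpoint.pt`, `./pretrained/slakh/config.gin` and `./pretrained/AE_slakh.pt`. Two pretrained models are available: one trained on [SLAKH 2100](http://www.slakh.com/), and one trained on multiple real-world instrumental recordings (Maestro, URMP, Filobass, GuitarSet...).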

In [None]:
import os

# The notebook sits one level below the project root. Move to the root so that the
# `diffusion` package and the relative `./pretrained` / `./audios` paths resolve.
# The guard makes the cell idempotent: re-running it (or starting from the root)
# no longer walks above the project root.
if not os.path.isdir("diffusion"):
    os.chdir("..")

In [None]:
import gin
import librosa
import numpy as np
import torch
from IPython.display import Audio, display

# Interactive mode lets gin configurables be registered again when cells are
# re-executed in the same kernel.
gin.enter_interactive_mode()

# Inference only: autograd stays disabled for the rest of the session.
torch.set_grad_enabled(False)

### Checkpoint setup

In [None]:
# Pretrained diffusion model (checkpoint + gin config) and its companion autoencoder.
# Point these at the other pretrained model to use it instead.
folder = "./pretrained/slakh/"
checkpoint_path = os.path.join(folder, "checkpoint.pt")
autoencoder_path = "./pretrained/AE_slakh.pt"
config = os.path.join(folder, "config.gin")

# Use the first GPU when one is available, otherwise fall back to the (slower) CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

### Instantiate the model and load the checkpoint

In [None]:
from diffusion.model import EDM_ADV

# Parse the gin config; it also provides the sample rate and the input length
# (in samples) used to crop the audio below.
gin.parse_config_file(config)
SR = gin.query_parameter("%SR")
audio_length = gin.query_parameter("%X_LENGTH")

# Instantiate the model (its constructor arguments presumably come from the gin bindings)
blender = EDM_ADV()

# Load the weights on CPU first so the checkpoint opens regardless of the device it
# was saved from; the model is moved to `device` below.
state_dict = torch.load(checkpoint_path, map_location="cpu")["model_state"]
# strict=False tolerates key mismatches; report them instead of hiding them.
incompatible = blender.load_state_dict(state_dict, strict=False)
if incompatible.missing_keys or incompatible.unexpected_keys:
    print(f"Checkpoint key mismatch: {len(incompatible.missing_keys)} missing, "
          f"{len(incompatible.unexpected_keys)} unexpected")

# TorchScript autoencoder, likewise mapped to CPU before being sent to `device`.
emb_model = torch.jit.load(autoencoder_path, map_location="cpu").eval().to(device)

# Send to device
blender = blender.eval().to(device)

#### Loading some audio files

In [None]:
# Demo pair (see the project webpage); replace with your own files to try other inputs.
path1, path2 = (
    "./audios/slakh/true/piano_guitar_1.wav",
    "./audios/slakh/target/piano_guitar_1.wav",
)

In [None]:
def load_audio(path, sr):
    """Load an audio file, crop it to the model input length and peak-normalize it.

    Args:
        path: audio file to read with ``librosa.load``.
        sr: sample rate passed to ``librosa.load`` (the file is resampled to it).

    Returns:
        A float tensor reshaped to ``(1, 1, n)`` with ``n <= audio_length``
        (the global read from the gin config). It is scaled so that its largest
        absolute sample is 1, or left unscaled if the crop is silent.
    """
    audio_full, _ = librosa.load(path, sr=sr)
    audio = audio_full[:audio_length]
    # Normalize by the absolute peak. Dividing by audio.max() alone overshoots
    # [-1, 1] when the negative peak is larger, flips polarity when every sample
    # is negative, and produces NaN/inf on silence.
    peak = np.abs(audio).max() if audio.size else 0.0
    if peak > 0:
        audio = audio / peak
    return torch.from_numpy(audio).reshape(1, 1, -1)


def process_audio(audio):
    """Encode a waveform and build its frame-aligned, min-max scaled CQT.

    Args:
        audio: waveform tensor; it is moved to the global ``device``.

    Returns:
        ``(z, cqt)``: the autoencoder embedding ``emb_model.encode(audio)`` and
        the output of ``blender.time_transform``. The last axis of ``cqt`` is
        resampled (nearest) to the frame count of ``z``, and ``cqt`` is scaled
        into roughly ``[0, 1)``.
    """
    signal = audio.to(device)
    latent = emb_model.encode(signal)

    # Resample the time-frequency frames so they line up with the latent frames.
    n_frames = latent.shape[-1]
    spectro = torch.nn.functional.interpolate(blender.time_transform(signal),
                                              size=n_frames,
                                              mode="nearest")

    # Min-max scaling; the 1e-4 keeps the division finite for a constant input.
    lo, hi = spectro.min(), spectro.max()
    spectro = (spectro - lo) / (hi - lo + 1e-4)
    return latent, spectro

In [None]:
audio1 = load_audio(path1, sr=SR)
audio2 = load_audio(path2, sr=SR)

# Listen to both inputs
for caption, clip in (("Sample 1", audio1), ("Sample 2", audio2)):
    print(caption)
    display(Audio(clip.squeeze(), rate=SR))

# Latent embeddings (timbre branch) and CQT features (structure branch) of each clip
z1, cqt1 = process_audio(audio1)
z2, cqt2 = process_audio(audio2)

#### Generation

In [None]:
# Sampler settings
nb_steps = 40  # number of diffusion steps
guidance = 2.0  # classifier-free guidance strength

In [None]:
# Compute structure representation
time_cond1, time_cond2 = blender.encoder_time(cqt1), blender.encoder_time(cqt2)

# Compute timbre representation
zsem1, zsem2 = blender.encoder(z1), blender.encoder(z2)

# Seed the RNG so the initial noise (and thus the generated audio) is the same on
# every "Restart & Run All". Change `seed` to explore other samples.
seed = 0
torch.manual_seed(seed)

# Initial noise, shared by both transfers so they differ only in their conditioning
x0 = torch.randn_like(z1)


def transfer(time_cond, zsem):
    """Sample from the shared noise ``x0`` under the given conditioning and decode it.

    Args:
        time_cond: structure conditioning (output of ``blender.encoder_time``).
        zsem: timbre conditioning (output of ``blender.encoder``).

    Returns:
        ``(latents, audio)``: the sampled latents and the decoded waveform as a
        squeezed NumPy array on the CPU.
    """
    latents = blender.sample(x0,
                             time_cond=time_cond,
                             zsem=zsem,
                             nb_step=nb_steps,
                             guidance=guidance,
                             guidance_type="time_cond")
    return latents, emb_model.decode(latents).cpu().numpy().squeeze()


print("Timbre of sample 1 and structure of sample 2")
xS, audio_out = transfer(time_cond2, zsem1)
display(Audio(audio_out, rate=SR))

print("Timbre of sample 2 and structure of sample 1")
xS, audio_out = transfer(time_cond1, zsem2)
display(Audio(audio_out, rate=SR))