# Audio-to-audio generation 


In [None]:
import gin

# Gin interactive mode is meant for notebooks: it lets configurables be
# redefined when cells are re-run (see the Gin documentation to confirm).
gin.enter_interactive_mode()
from IPython.display import display, Audio
import torch
import numpy as np
import librosa
import sys
import os

# Make the parent directory importable — presumably where the `after` package
# imported further down lives relative to this notebook.
sys.path.append('..')
# This notebook only runs inference: disable autograd globally.
torch.set_grad_enabled(False)

## Model Loading

A model can be loaded either from a training checkpoint (a .pt file stored by default in ./after_runs/#model_name) or from an already exported .ts (TorchScript) file.

### Load a model from a training checkpoint

In [14]:
# Directory of the training run; it must contain "config.gin" and the
# "checkpoint<step>_EMA.pt" files.
model_path = ""
# Training step of the checkpoint to load; use None for the last step.
step = None
# Path to the exported (torchscript) autoencoder.
autoencoder_path = ""
# Use the first GPU when one is available, otherwise fall back to the CPU so
# that the notebook still runs on machines without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [15]:
class CheckpointModel(torch.nn.Module):
    """Bundle a diffusion model restored from a checkpoint with its autoencoder.

    Exposes a small common interface (`ae_encode`, `ae_decode`, `timbre`,
    `structure`, `sample`). Every method moves its tensor inputs to
    `self.device`, so callers may pass tensors living on any device.
    """

    def __init__(self, model, emb_model, device=device):
        """Store both models and move the whole wrapper to `device`.

        Args:
            model: diffusion model; must expose `encoder`, `encoder_time`
                and `sample`.
            emb_model: autoencoder; must expose `encode` and `decode`.
            device: device on which the models run and to which inputs are
                moved.
        """
        super().__init__()
        self.model = model
        self.emb_model = emb_model
        self.device = device
        self.to(device)

    def ae_encode(self, x):
        """Encode a waveform tensor with the autoencoder.

        The input is first reshaped to three dimensions with a singleton
        middle axis: a multi-dimensional input keeps its first axis and has
        the remaining ones flattened, anything else becomes a single item.
        """
        if len(x.shape) > 1:
            x = x.reshape(x.shape[0], 1, -1)  # Flatten everything after axis 0
        else:
            x = x.reshape(1, 1, -1)
        return self.emb_model.encode(x.to(self.device))

    def ae_decode(self, z):
        """Decode latents with the autoencoder.

        Returns the result on the CPU with all singleton dimensions removed.
        """
        return self.emb_model.decode(z.to(self.device)).cpu().squeeze()

    def timbre(self, z):
        """Return the output of the model's `encoder` for the latents `z`."""
        return self.model.encoder(z.to(self.device))

    def structure(self, z):
        """Return the output of the model's `encoder_time` for the latents `z`."""
        return self.model.encoder_time(z.to(self.device))

    def sample(self, noise, z_structure, z_timbre, guidance_timbre,
               guidance_structure, nb_steps):
        """Run the model's sampler from `noise` under the given conditioning.

        Args:
            noise: initial noise tensor.
            z_structure: structure conditioning, passed as `time_cond`.
            z_timbre: timbre conditioning, passed as `cond`.
            guidance_timbre: guidance strength for the timbre.
            guidance_structure: guidance strength for the structure.
            nb_steps: number of sampling steps.

        Returns:
            Whatever `self.model.sample` returns.
        """
        # Like the other methods, bring every tensor onto the model's device:
        # noise created from CPU latents would otherwise be mixed with
        # conditioning tensors that `timbre` / `structure` produced on
        # `self.device`.
        noise = noise.to(self.device)
        z_structure = z_structure.to(self.device)
        z_timbre = z_timbre.to(self.device)

        zout = self.model.sample(noise,
                                 time_cond=z_structure,
                                 cond=z_timbre,
                                 nb_steps=nb_steps,
                                 guidance_structure=guidance_structure,
                                 guidance_timbre=guidance_timbre)
        return zout

In [None]:
from after.diffusion import RectifiedFlow

# Resolve the checkpoint to load. Checkpoints are expected to be named
# "checkpoint<step>_EMA.pt" inside `model_path`.
if step is None:
    # Use the highest step among the EMA checkpoints only. Other files that
    # merely start with "checkpoint" are ignored instead of breaking the
    # parsing of the step number.
    prefix, suffix = "checkpoint", "_EMA.pt"
    ema_steps = [
        int(fname[len(prefix):-len(suffix)])
        for fname in os.listdir(model_path)
        if fname.startswith(prefix) and fname.endswith(suffix)
        and fname[len(prefix):-len(suffix)].isdigit()
    ]
    if not ema_steps:
        raise FileNotFoundError(
            f"No 'checkpoint<step>_EMA.pt' file found in {model_path!r}")
    step = max(ema_steps)

checkpoint_file = os.path.join(model_path, f"checkpoint{step}_EMA.pt")
config = os.path.join(model_path, "config.gin")

# Parse the training config and read the macros used in this notebook.
gin.parse_config_file(config)
SR = gin.query_parameter("%SR")  # passed below as the audio sample rate
n_signal = gin.query_parameter("%N_SIGNAL")  # used below to crop latents for timbre
latent_size = gin.query_parameter("%IN_SIZE")

# Instantiate the model (presumably configured through the gin file parsed above)
blender = RectifiedFlow(device=device)

# Load the weights on the CPU first. strict=False tolerates keys that are
# missing from, or unexpected in, the checkpoint.
state_dict = torch.load(checkpoint_file, map_location="cpu")["model_state"]
blender.load_state_dict(state_dict, strict=False)

# Autoencoder (torchscript export)
emb_model = torch.jit.load(autoencoder_path).eval()

# Encode a dummy signal to measure how many audio samples map to one latent
# frame.
dummy = torch.randn(1, 1, 8192)
z = emb_model.encode(dummy)

ae_ratio = dummy.shape[-1] / z.shape[-1]

# Switch to evaluation mode; CheckpointModel moves both models to `device`.
blender = blender.eval()

model = CheckpointModel(blender, emb_model, device=device)


## Load a model from a TorchScript file

In [None]:
# Path of the exported .ts model and of the autoencoder.
ts_path = ""
autoencoder_path = ""

ts_model = torch.jit.load(ts_path)

# You can set emb_model to None if the autoencoder path is not known, but the
# streaming model embedded in the .ts is then used, which has some latency.
emb_model = torch.jit.load(autoencoder_path)
emb_model = emb_model.eval()


In [27]:
class TSModel(torch.nn.Module):
    """Wrap an exported torchscript model, optionally with a separate autoencoder.

    When `emb_model` is None, encoding and decoding fall back to the
    `emb_model_structure` module embedded in the torchscript model.
    """

    def __init__(self, ts_model, emb_model=None):
        super().__init__()
        self.model = ts_model
        self.emb_model = emb_model

    def _autoencoder(self):
        """Return the autoencoder to use: the explicit one, else the embedded one."""
        if self.emb_model is None:
            return self.model.emb_model_structure
        return self.emb_model

    def ae_encode(self, x):
        """Reshape `x` to three dimensions (singleton middle axis) and encode it."""
        # Multi-dimensional inputs keep their first axis; anything else is
        # treated as a single item.
        batch = x.shape[0] if x.dim() > 1 else 1
        x = x.reshape(batch, 1, -1)
        return self._autoencoder().encode(x)

    def ae_decode(self, z):
        """Decode `z`, squeeze singleton dimensions and return the result on CPU."""
        decoded = self._autoencoder().decode(z)
        return decoded.squeeze().cpu()

    def timbre(self, z):
        """Return `encoder.forward_stream(z)` from the torchscript model."""
        timbre_encoder = self.model.encoder
        return timbre_encoder.forward_stream(z)

    def structure(self, z):
        """Return `encoder_time.forward_stream(z)` from the torchscript model."""
        structure_encoder = self.model.encoder_time
        return structure_encoder.forward_stream(z)

    def sample(self, noise, z_structure, z_timbre, guidance_timbre,
               guidance_structure, nb_steps):
        """Configure the torchscript model through its setters, then sample.

        `z_structure` is passed as `time_cond` and `z_timbre` as `cond`.
        """
        ts = self.model

        # The exported model takes its sampling settings through setters
        # rather than through arguments of `sample`.
        ts.set_guidance_timbre(guidance_timbre)
        ts.set_guidance_structure(guidance_structure)
        ts.set_nb_steps(nb_steps)

        return ts.sample(noise, time_cond=z_structure, cond=z_timbre)

In [28]:
# Wrap the torchscript model so it exposes the interface used in the cells below.
model = TSModel(ts_model=ts_model, emb_model=emb_model)

## Audio Samples

You can either load samples from a prepared LMDB database, or directly from audio files (below). By default, audio files are cut to the training length.

#### Load audio from the dataset

In [29]:
from after.dataset import SimpleDataset

# Path of the prepared database.
db_path = ""
# Only the "z" entry of each item is requested; it is read in the next cell
# and decoded by the autoencoder there.
dataset = SimpleDataset(path=db_path, keys=["z"])


In [None]:
# Fetch two items from the dataset and add a leading batch axis.
# Index each "z" with [..., :n_signal] to crop it.
z1 = torch.tensor(dataset[1]["z"]).unsqueeze(0)
z2 = torch.tensor(dataset[2890]["z"]).unsqueeze(0)

# Decode the latents back to audio for listening.
x1 = model.ae_decode(z1)
x2 = model.ae_decode(z2)

for waveform in (x1, x2):
    display(Audio(waveform, rate=SR))

#### Load audio from files

In [None]:
# Paths of the two audio files to load.
path1 = ""
path2 = ""

# Load both files as mono signals resampled to SR.
x1, _ = librosa.load(path1, sr=SR, mono=True)
x2, _ = librosa.load(path2, sr=SR, mono=True)
print(x1.shape)

for waveform in (x1, x2):
    display(Audio(waveform, rate=SR))

# Convert to tensors, then encode with the autoencoder.
x1 = torch.tensor(x1)
x2 = torch.tensor(x2)
z1 = model.ae_encode(x1)
z2 = model.ae_encode(x2)

#### Generation

In [32]:
# Number of diffusion steps
nb_steps = 10
# Classifier-free guidance strength for the timbre
guidance_timbre = 2.0
# Classifier-free guidance strength for the structure
guidance_structure = 1.0

In [33]:
# Compute structure representation
z_structure1, z_structure2 = model.structure(z1), model.structure(z2)

# Compute timbre representation. Timbre must be computed on latent codes of
# length n_signal; by default the first n_signal elements are used.
z_timbre1, z_timbre2 = model.timbre(z1[..., :n_signal]), model.timbre(
    z2[..., :n_signal])

# Transfer: structure of sample 2 combined with the timbre of sample 1.
z_structure = z_structure2
z_timbre = z_timbre1

# Sample initial noise. It is shaped like z2, the latents the selected
# structure was computed from, so that both agree when the two inputs have
# different lengths (with equal shapes this is the same as using z1).
# NOTE(review): assumes the structure representation keeps the temporal
# length of its input latents — confirm against the model.
noise = torch.randn_like(z2)

print("Transfer")

xS = model.sample(
    noise=noise,
    z_structure=z_structure,
    z_timbre=z_timbre,
    nb_steps=nb_steps,
    guidance_structure=guidance_structure,
    guidance_timbre=guidance_timbre,
)

# Decode the sampled latents back to audio.
audio_out = model.ae_decode(xS)
display(Audio(audio_out, rate=SR))