# MIDI-to-audio generation 


In [None]:
import gin

# Switch gin to its interactive mode before anything else is imported.
# NOTE(review): presumably this lets gin tolerate configurables being
# re-registered when notebook cells are re-run — confirm in the gin docs.
gin.enter_interactive_mode()

from IPython.display import display, Audio
import torch
import numpy as np
import librosa
import matplotlib.pyplot as plt
import sys
import os

# Make the parent directory importable (presumably where the `after` package
# imported by later cells lives).
sys.path.append('..')
# The notebook only runs inference, so autograd is disabled globally.
torch.set_grad_enabled(False)

## Model Loading

A model can be loaded either from a training checkpoint (a .pt file, stored by default in ./after_runs/#model_name) or from an already exported .ts file.

### Load a model from a training Checkpoint

In [2]:
model_path = ""  # Training-run directory; config.gin and the checkpoint files are read from it
step = None  # Training step of the checkpoint to load; use None to load the latest one
autoencoder_path = ""  # Path of the autoencoder, loaded with torch.jit.load
device = "cuda:0"  # Torch device the models are moved to

### Instantiate the model and load the checkpoint

In [3]:
class CheckpointModel(torch.nn.Module):
    """Bundle a checkpoint-loaded generative model with its autoencoder.

    Every method moves its tensor arguments to ``self.device`` before
    calling the wrapped modules, so callers may pass CPU tensors.
    """

    def __init__(self, model, emb_model, device=device):
        """Store both sub-models and move the whole wrapper to ``device``."""
        super().__init__()
        self.model = model
        self.emb_model = emb_model
        self.device = device
        self.to(device)

    def ae_encode(self, x):
        """Encode ``x`` with the autoencoder.

        The input is first reshaped to ``(batch, 1, -1)``; an input with a
        single dimension is treated as a batch of one.
        """
        batch = x.shape[0] if x.ndim > 1 else 1
        x = x.reshape(batch, 1, -1)
        return self.emb_model.encode(x.to(self.device))

    def ae_decode(self, z):
        """Decode ``z`` with the autoencoder; the result is on CPU, squeezed."""
        decoded = self.emb_model.decode(z.to(self.device))
        return decoded.cpu().squeeze()

    def timbre(self, z):
        """Run the wrapped model's ``encoder`` on ``z``."""
        return self.model.encoder(z.to(self.device))

    def sample(self, noise, z_structure, z_timbre, guidance_timbre,
               guidance_structure, nb_steps):
        """Call the wrapped model's ``sample`` with all tensors on the device.

        ``z_structure`` is passed as ``time_cond`` and ``z_timbre`` as
        ``cond``; the remaining arguments are forwarded unchanged.
        """
        dev = self.device
        return self.model.sample(noise.to(dev),
                                 time_cond=z_structure.to(dev),
                                 cond=z_timbre.to(dev),
                                 nb_steps=nb_steps,
                                 guidance_structure=guidance_structure,
                                 guidance_timbre=guidance_timbre)

In [None]:
from after.diffusion import RectifiedFlow

# Checkpoints are looked up under the name "checkpoint<step>_EMA.pt".
ckpt_prefix, ckpt_suffix = "checkpoint", "_EMA.pt"

if step is None:
    # Only well-formed EMA checkpoints are considered, so stray files that
    # merely start with "checkpoint" cannot break the step parsing.
    files = [
        f for f in os.listdir(model_path)
        if f.startswith(ckpt_prefix) and f.endswith(ckpt_suffix)
    ]
    steps = [f[len(ckpt_prefix):-len(ckpt_suffix)] for f in files]
    steps = [int(s) for s in steps if s.isdigit()]
    if not steps:
        raise FileNotFoundError(
            "No checkpoint<step>_EMA.pt file found in " + repr(model_path))
    step = max(steps)  # latest available training step

checkpoint_file = os.path.join(model_path,
                               ckpt_prefix + str(step) + ckpt_suffix)
config = os.path.join(model_path, "config.gin")

# Parse the training configuration and read the macros used by later cells
gin.parse_config_file(config)
SR = gin.query_parameter("%SR")
n_signal = gin.query_parameter("%N_SIGNAL")
latent_size = gin.query_parameter("%IN_SIZE")

# Instantiate the model (its remaining arguments come from the gin config)
blender = RectifiedFlow(device=device)

# Load the weights stored under the "model_state" key.
# NOTE: torch.load unpickles the file — only open checkpoints you trust.
# strict=False means missing or unexpected keys are ignored silently.
state_dict = torch.load(checkpoint_file, map_location="cpu")["model_state"]
blender.load_state_dict(state_dict, strict=False)

# Autoencoder (TorchScript)
emb_model = torch.jit.load(autoencoder_path).eval()

# Encode a dummy signal to measure how many samples map to one latent frame
dummy = torch.randn(1, 1, 8192)
z = emb_model.encode(dummy)

ae_ratio = dummy.shape[-1] / z.shape[-1]

# Evaluation mode; CheckpointModel moves both models to `device`
blender = blender.eval()

model = CheckpointModel(blender, emb_model, device=device)


## Load a model from a TorchScript file

In [21]:
ts_path = ""  # Path of the exported .ts model
autoencoder_path = ""  # Path of the autoencoder, loaded with torch.jit.load

ts_model = torch.jit.load(ts_path)
emb_model = torch.jit.load(autoencoder_path).eval()
# If the autoencoder path is not known, set emb_model to None instead: TSModel then falls back to the
# streaming autoencoder embedded in the .ts file, which has some latency.


In [27]:
class TSModel(torch.nn.Module):
    """Wrap an exported TorchScript model behind the notebook's model API.

    When ``emb_model`` is given it is used for encoding and decoding;
    otherwise the codecs embedded in the TorchScript model are used.
    """

    def __init__(self, ts_model, emb_model=None):
        """Keep references to the TorchScript model and optional autoencoder."""
        super().__init__()
        self.model = ts_model
        self.emb_model = emb_model

    def ae_encode(self, x):
        """Encode ``x`` after reshaping it to ``(batch, 1, -1)``.

        An input with a single dimension is treated as a batch of one.
        """
        batch = x.shape[0] if x.ndim > 1 else 1
        x = x.reshape(batch, 1, -1)

        codec = (self.emb_model if self.emb_model is not None else
                 self.model.emb_model_structure)
        return codec.encode(x)

    def ae_decode(self, z):
        """Decode ``z``; the result is squeezed and moved to CPU."""
        codec = (self.emb_model if self.emb_model is not None else
                 self.model.emb_model_timbre)
        return codec.decode(z).squeeze().cpu()

    def timbre(self, z):
        """Run the model's ``encoder.forward_stream`` on ``z``."""
        return self.model.encoder.forward_stream(z)

    def structure(self, z):
        """Run the model's ``encoder_time.forward_stream`` on ``z``."""
        return self.model.encoder_time.forward_stream(z)

    def sample(self, noise, z_structure, z_timbre, guidance_timbre,
               guidance_structure, nb_steps):
        """Set the sampling options on the model, then sample from ``noise``.

        ``z_structure`` is passed as ``time_cond`` and ``z_timbre`` as
        ``cond``.
        """
        ts = self.model
        # The exported model takes these settings through setters rather
        # than as arguments of sample().
        ts.set_guidance_timbre(guidance_timbre)
        ts.set_guidance_structure(guidance_structure)
        ts.set_nb_steps(nb_steps)

        return ts.sample(noise, time_cond=z_structure, cond=z_timbre)

In [28]:
# Wrap the TorchScript model; this rebinds `model`, replacing any CheckpointModel created earlier.
model = TSModel(ts_model, emb_model=emb_model)

## Audio Samples

You can either load a sample from a prepared lmdb database, or directly from audio files (below). By default, audio files are cut to the training length.

#### Load audio from the dataset

In [18]:
from after.dataset import SimpleDataset
from IPython.display import display, Audio

db_path = ""  # Path of the prepared database
# Request the "z" and "midi" entries, which are the ones the next cell reads from each item.
dataset = SimpleDataset(path=db_path, keys=["z", "midi"])

In [None]:
# Two dataset items: d1 provides the timbre, d2 the midi.
d1 = dataset[380]  # guitar
d2 = dataset[540]

# Keep the first n_signal positions of each "z" entry and add a leading
# batch dimension.
z1 = torch.tensor(d1["z"][..., :n_signal]).unsqueeze(0)
z2 = torch.tensor(d2["z"][..., :n_signal]).unsqueeze(0)


def normalize(array):
    """Rescale ``array`` by its own minimum and maximum.

    The minimum maps to 0; a 1e-6 term in the denominator avoids a
    division by zero when all values are equal.
    """
    lowest = array.min()
    span = array.max() - lowest + 1e-6
    return (array - lowest) / span


# Number of audio samples per latent frame, as configured for training
ae_ratio = gin.query_parameter("utils.collate_fn.ae_ratio")
full_length = dataset[0]["z"].shape[-1]
# One timestamp (in seconds) per latent frame of a full dataset item
times = np.linspace(0, full_length * ae_ratio / SR, full_length)

# The piano roll comes from d2, the midi example, so that it matches the
# audio played under "Midi Example" below; d1 only provides the timbre.
pr = d2["midi"].get_piano_roll(times=times)
pr = pr / 127  # scale by 127, presumably the maximum MIDI velocity
pr = pr[..., :n_signal]  # same crop as the latents
pr = torch.from_numpy(pr).float().reshape(1, 128, -1)

# Decode both latent sequences back to audio for listening
x1, x2 = model.ae_decode(z1), model.ae_decode(z2)

print("Audio for timbre")
display(Audio(x1, rate=SR))

print("Midi Example")
display(Audio(x2, rate=SR))
plt.imshow(pr[0].cpu().numpy(), aspect="auto", origin="lower")
plt.show()

#### Load audio from files

In [14]:
audio_path = ""  # Audio file that provides the timbre
midi_path = ""  # MIDI file that provides the notes

offset_midi = 0  # Start time in the midi file, in seconds
duration = 10  # Duration in seconds of the midi chunk and of the generated audio: the model generates audio from the midi content between offset_midi and offset_midi + duration

offset_audio = 0  # Start time in the audio file, in seconds. The timbre is later computed on the first n_signal latent frames only
duration_audio = 10  # Intended duration of the audio excerpt, in seconds

In [None]:
import pretty_midi

# Samples per latent frame. It is queried before its first use and forced to
# an int so that the lengths derived from it are valid slice bounds and
# np.linspace sizes, whichever cells were run before this one.
ae_ratio = int(gin.query_parameter("utils.collate_fn.ae_ratio"))

# Timbre reference audio: honour the configured excerpt duration
x1, _ = librosa.load(audio_path,
                     sr=SR,
                     mono=True,
                     offset=offset_audio,
                     duration=duration_audio)

print(x1.shape)
display(Audio(x1, rate=SR))

# Round the requested duration down to a whole number of latent frames
real_duration_samples = int(duration * SR) // ae_ratio * ae_ratio
real_duration_time = real_duration_samples / SR
z_length = real_duration_samples // ae_ratio

x1 = torch.tensor(x1)
z1 = model.ae_encode(x1)[..., :z_length]

# Get the midi: one piano-roll column per latent frame of the chunk
midi = pretty_midi.PrettyMIDI(midi_path)
times = np.linspace(offset_midi, offset_midi + real_duration_time, z_length)
pr = midi.get_piano_roll(times=times)
pr = pr / 127  # scale by 127, presumably the maximum MIDI velocity
pr = torch.from_numpy(pr).float().reshape(1, 128, -1)

print("Piano Roll")
plt.imshow(pr[0].cpu().numpy(), aspect="auto", origin="lower")
plt.show()

## Generation

In [16]:
nb_steps = 15  # Number of diffusion steps
guidance_timbre = 2.0  # Classifier-free guidance strength for timbre
guidance_structure = 3.0  # Classifier-free guidance strength for structure

In [None]:
# Structure conditioning: the piano roll is used as is
z_structure = pr

# Timbre conditioning: it must be computed on latent codes of length
# n_signal, so by default only the first n_signal elements of z1 are used
z_timbre = model.timbre(z1[..., :n_signal])

# Initial noise, shaped like the latent sequence
noise = torch.randn_like(z1)

print("Transfer")

sampling_options = dict(
    nb_steps=nb_steps,
    guidance_structure=guidance_structure,
    guidance_timbre=guidance_timbre,
)
xS = model.sample(noise=noise,
                  z_structure=z_structure,
                  z_timbre=z_timbre,
                  **sampling_options)

# Decode the sampled latents and play the result
audio_out = model.ae_decode(xS)
display(Audio(audio_out, rate=SR))