# Sound Generation

## Overview

![System](docs/SG_System.jpg)

## Encoder

![Encoder](docs/SG_Encoder.jpg)

## Decoder

![Decoder](docs/SG_Decoder.jpg)

In [56]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
import numpy as np
import soundfile as sf
import tsms
from tcvae import dataset, localconfig, model, data_handler
from tcvae.compute_measures import heuristic_names
from IPython.display import Audio

## Load Config, Data and Model

In [57]:
# Restore the configuration stored next to the checkpoint, then override a few
# fields for this notebook session.
conf = localconfig.LocalConfig()
conf.load_config_from_file("checkpoints/mt_categorical_256/conf.txt")
conf.simple_encoder = False
conf.simple_decoder = False

conf.dataset_dir = "complete_dataset"
# The helper functions further down build arrays with a leading dimension of 1
# and index element [0], i.e. they work on one example at a time.
conf.batch_size = 1

# Two-way lookup between heuristic measure names and their position in
# heuristic_names.
measure_to_index = dict((n, i) for i, n in enumerate(heuristic_names))
index_to_measure = dict((v, k) for k, v in measure_to_index.items())

print("Configuration loaded")

# get_dataset returns three splits; the first one is discarded here
# (presumably the training split - TODO confirm).
_, valid, test = dataset.get_dataset(conf)
test_iter = iter(test)
valid_iter = iter(valid)

print("Data loaded")

# The model is called once on a real batch before load_weights - presumably so
# that Keras creates its variables first (TODO confirm). Note that this consumes
# one batch from valid_iter.
mt_model = model.MtVae(conf)
_ = mt_model(next(valid_iter))

mt_model.load_weights("checkpoints/mt_categorical_256/38_mt_categorical_256_12.63.h5")

print("Model loaded")

Configuration loaded
Data loaded
Model loaded


In [58]:
# Display the decoder sub-model object as the cell output (sanity check only).
mt_model.decoder

<tensorflow.python.keras.engine.functional.Functional at 0x15b46dc1d00>

## Helper Functions

In [59]:
def change_params(batch, params):
    """Override the conditioning inputs of a single-example batch.

    Relies on the notebook-level names ``conf``, ``heuristic_names`` and
    ``measure_to_index``.

    Args:
        batch: dict-like batch holding one example. The keys "note_number",
            "velocity" and "measures" are read and, when requested,
            overwritten. The dict is updated in place.
        params: dict with the optional keys
            "note_number" - MIDI pitch, must satisfy 40 <= value <= 88;
            "velocity"    - must satisfy 25 <= value <= 127;
            "measures"    - dict mapping a name from ``heuristic_names`` to
                            the new value for that measure.
            A missing key, ``None``, or an empty "measures" dict leaves the
            corresponding input untouched.

    Returns:
        The same ``batch`` object that was passed in.

    Raises:
        AssertionError: if a value is out of range or a measure name is not
            in ``heuristic_names``.
    """
    note_number = params.get("note_number")
    velocity = params.get("velocity")
    measures = params.get("measures")

    if note_number is not None:
        # NOTE(review): the 40..88 bounds are hard-coded; presumably they match
        # conf.starting_midi_pitch and conf.num_pitches - confirm.
        assert 40 <= note_number <= 88
        original_note = np.argmax(batch["note_number"], axis=-1)[0] + conf.starting_midi_pitch
        print(f"note_number changed from {original_note} to {note_number}")
        # Replace the pitch with a fresh one-hot row.
        updated_note = np.zeros((1, conf.num_pitches))
        updated_note[:, note_number - conf.starting_midi_pitch] = 1.
        batch["note_number"] = updated_note

    if velocity is not None:
        assert 25 <= velocity <= 127
        original_vel = (np.argmax(batch["velocity"], axis=-1)[0] + 1) * 25
        print(f"velocity changed from {original_vel} to {velocity}")
        # Velocities are bucketed in steps of 25: 25 -> index 0, 127 -> index 4.
        updated_vel = np.zeros((1, conf.num_velocities))
        updated_vel[:, int(velocity / 25 - 1)] = 1.
        batch["velocity"] = updated_vel

    # Guard on the user-supplied `measures` (the original tested the module
    # level lookup table, which is never None).
    if measures:
        # np.array copies, and accepts both a tensor and an ndarray, so the
        # function can be applied again to an already-updated batch.
        updated_measures = np.array(batch["measures"])
        for m, val in measures.items():
            assert m in heuristic_names
            column = measure_to_index[m]
            original_value = updated_measures[:, column][0]
            print(f"{m} changed from {original_value} to {val}")
            updated_measures[:, column] = val
        batch["measures"] = updated_measures

    return batch


def get_prediction(batch, conf, prediction=None):
    """Synthesize audio for one batch.

    Args:
        batch: dict-like batch; "note_number" and "mask" are read from it.
        conf: configuration object providing ``starting_midi_pitch``,
            ``data_handler``, ``sample_rate`` and ``frame_size``.
        prediction: optional model output. When given it is passed through
            ``conf.data_handler.prediction_transform`` and synthesized;
            when ``None`` the batch itself is synthesized.

    Returns:
        The synthesized audio as a NumPy array with singleton dimensions
        squeezed out.
    """
    batch_copy = batch.copy()
    pitch = np.argmax(batch_copy["note_number"], axis=-1) + conf.starting_midi_pitch

    if prediction is None:
        source = batch_copy
    else:
        source = conf.data_handler.prediction_transform(prediction)

    # denormalize yields the three inputs expected by harmonic_synthesis.
    synth_f, synth_m, synth_p = conf.data_handler.denormalize(
        source, batch_copy["mask"], pitch)
    waveform = tsms.core.harmonic_synthesis(
        synth_f, synth_m, synth_p, conf.sample_rate, conf.frame_size)
    return np.squeeze(waveform.numpy())

def encode_decode_batch(normalized_data):
    """Round-trip four entries of a batch through mu-law encode + decode.

    The entries "f0_shifts", "mag_env", "h_freq_shifts" and "h_mag_dist" are
    each replaced by ``mu_law_decode(mu_law_encode(value))``. "mag_env" and
    "h_mag_dist" are processed with ``range_0_1=True``, the other two with
    ``range_0_1=False``.

    Args:
        normalized_data: dict-like batch; modified in place.

    Returns:
        The same ``normalized_data`` object.
    """
    # Second argument of the mu-law helpers (presumably the number of
    # quantization levels - TODO confirm).
    levels = 256

    def _round_trip(values, unit_range):
        encoded = data_handler.mu_law_encode(values, levels, range_0_1=unit_range)
        return data_handler.mu_law_decode(encoded, levels, range_0_1=unit_range)

    # (key, range_0_1 flag), processed in the same order as before.
    plan = (
        ("f0_shifts", False),
        ("mag_env", True),
        ("h_freq_shifts", False),
        ("h_mag_dist", True),
    )
    for key, unit_range in plan:
        normalized_data[key] = _round_trip(normalized_data[key], unit_range)

    return normalized_data

def get_audios(batch, update_params=None, use_encoder=True, encode_decode=False):
    """Produce a predicted and a reference audio clip for one batch.

    Relies on the notebook-level names ``conf`` and ``mt_model``.

    Args:
        batch: dict-like batch holding one example. A shallow copy is made,
            so keys are never added to or replaced in the caller's dict.
        update_params: optional dict forwarded to ``change_params``
            ("note_number", "velocity", "measures"). It may additionally
            contain "z": a latent vector of 16 values that replaces the
            latent code.
        use_encoder: when True, the latent code is obtained from
            ``mt_model.encoder``. When False and no "z" is supplied, a random
            latent code is drawn.
        encode_decode: when True, the model's decoder is bypassed and the
            prediction is the batch after ``encode_decode_batch``; when
            False the prediction comes from ``mt_model.decoder``.

    Returns:
        ``(audio_pred, audio_gt)`` where ``audio_gt`` is synthesized from the
        unmodified batch and ``audio_pred`` from the selected prediction path.

    Raises:
        AssertionError: if a user-supplied "z" does not have shape (16,).
    """
    batch = batch.copy()

    # Reference audio is rendered before any parameter is overridden.
    audio_gt = get_prediction(batch, conf, prediction=None)

    if use_encoder:
        batch["z"] = mt_model.encoder.predict(batch)

    if update_params is not None:
        batch = change_params(batch, update_params)
        if "z" in update_params:
            # Fixed: the original expanded the undefined name `z` instead of
            # the user-supplied value.
            updated_z = np.expand_dims(update_params.get("z"), axis=0)
            assert updated_z.shape == (1, 16)
            print("Updating Z from user input")
            batch["z"] = updated_z

    if "z" not in batch:
        print("Updating z from random values")
        # NOTE(review): np.random.rand draws uniform samples in [0, 1);
        # confirm this matches the latent distribution the decoder expects.
        batch["z"] = np.random.rand(1, 16)

    if encode_decode:
        batch_ed = encode_decode_batch(batch)
        audio_pred = get_prediction(batch_ed, conf, prediction=None)
    else:
        prediction = mt_model.decoder.predict(batch)
        audio_pred = get_prediction(batch, conf, prediction=prediction)

    return audio_pred, audio_gt

## Get Predictions

In [36]:
# Pull the next batch from the validation iterator to experiment with.
batch = next(valid_iter)

In [37]:
# Conditioning overrides handed to get_audios (and from there to change_params).
# Uncomment an entry to apply it; with every entry commented out the "measures"
# dict is empty and the batch is left unchanged.
update_params = {
    # "note_number": 49,
    # "velocity": 127,
    "measures": {
        # "bass": 1.,
        # "mid": 1.,
        # "high_mid": 0.,
        # "high": 1.,
        # "inharmonicity": 1.,
        # "even_odd": 1.,
        # "sparse_rich": 1.,
        # "decay_time": 0.3
    }
}

audio_pred, audio_gt = get_audios(batch, update_params, use_encoder=True)

In [38]:
# Playback widget for the reference audio synthesized from the batch itself.
Audio(audio_gt, rate=conf.sample_rate)

In [39]:
# Playback widget for the audio synthesized from the model's prediction.
Audio(audio_pred, rate=conf.sample_rate)


## Export Audio Files

In [61]:
# Create a fresh iterator over the test split so the export below does not
# depend on how many batches were consumed from the earlier test_iter.
test_iter = iter(test)

def write_audio(audio, conf, audio_path):
    """Peak-normalize ``audio`` and write it to ``audio_path``.

    Args:
        audio: array-like waveform.
        conf: configuration object; ``conf.sample_rate`` is used as the
            sample rate of the written file.
        audio_path: destination passed to ``sf.write``. When it is a
            filesystem path, its parent directory is created if missing.

    The signal is divided by its maximum absolute value. Silent or empty
    input is written unscaled, because dividing by a zero peak would fill
    the file with NaNs and ``np.max`` raises on an empty array.
    """
    audio = np.asarray(audio)
    peak = np.max(np.abs(audio)) if audio.size else 0
    if peak > 0:
        audio = audio / peak

    # Only touch the filesystem for real paths; sf.write may also be given
    # other kinds of targets, which are passed through untouched.
    if isinstance(audio_path, (str, os.PathLike)):
        parent_dir = os.path.dirname(audio_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    sf.write(audio_path, audio, samplerate=conf.sample_rate)


# Export predicted / reference audio pairs for the first 20 test batches into
# <current working directory>/predictions.
export_dir = os.path.join(os.getcwd(), "predictions")

for i in range(20):
    batch = next(test_iter)
    audio_pred, audio_gt = get_audios(batch, encode_decode=False)

    pred_path = os.path.join(export_dir, f"{i}_pred.wav")
    true_path = os.path.join(export_dir, f"{i}_true.wav")

    # Prediction first, then the reference, as before.
    write_audio(audio_pred, conf, pred_path)
    write_audio(audio_gt, conf, true_path)

    print(i)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
