In [None]:
import torch
import torchaudio
from einops import rearrange
from stable_audio_tools import get_pretrained_model
from stable_audio_tools.inference.generation import generate_diffusion_cond
from tqdm.notebook import tqdm
from AudacityHelper import AudacityPipeline
import os

# Pick the best available accelerator: CUDA first, then Apple-Silicon MPS, else CPU.
# `torch.backends.mps.is_available()` is the documented availability check;
# `torch.mps.is_available()` is not present in every PyTorch release.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print("Using {}".format(device))

#### Encode Thy Latents

In [None]:
# Choose your model! (Read this in an announcer voice.) The normal (1.0) model would also work, but we use small.

In [None]:
# Download model | Stable Audio Open (small)
# `https://huggingface.co/stabilityai/stable-audio-open-small`
# get_pretrained_model returns the model together with its config mapping.
model, model_config = get_pretrained_model("stabilityai/stable-audio-open-small")
# model, model_config = get_pretrained_model("stabilityai/stable-audio-open-1.0")
# Read these from the config so later cells do not have to hard-code them.
sample_rate = model_config["sample_rate"]
sample_size = model_config["sample_size"]

# Move the model to the selected device and switch it to eval mode for inference.
model = model.to(device).eval()

In [None]:
# Fallback for bad audio files: files that fail to encode are re-processed through Audacity.
# To use this, Audacity must be open with scripting enabled
# (ask GPT - it's pretty easy to do, just a little finicky).
ap = AudacityPipeline()

In [None]:
def get_autoencoder(model):
    """Return the autoencoder nested inside ``model``'s pretransform.

    Looks up the ``'pretransform'`` submodule of ``model`` and then that
    submodule's child registered under the name ``"model"``.

    Returns:
        The child module, or ``None`` when the pretransform has no child
        named ``"model"``.

    Raises:
        KeyError: if ``model`` has no ``'pretransform'`` submodule.
    """
    pretransform = model._modules['pretransform']
    children = pretransform._modules
    return children.get("model")

# Pull the autoencoder out of the loaded model and make sure it is on `device`.
autoencoder = get_autoencoder(model).to(device)

# Use the first parameter as the reference for the device and dtype that
# audio tensors are cast to in clean_audio_dim before encoding.
sample_param = next(autoencoder.parameters())
audio_device = sample_param.device
audio_dtype = sample_param.dtype
print(f"Using {audio_device} device {audio_dtype} dtype")

def clean_audio_dim(audio, debug=False, device=None, dtype=None):
    """Bring a waveform tensor into the 3-D layout used for encoding.

    A 1-D tensor gets two leading axes, a 2-D tensor (e.g. the
    ``(channels, samples)`` layout returned by ``torchaudio.load``) gets one
    leading axis, and a 3-D tensor is used as is. If axis 1 has size 1 it is
    duplicated so the result has two entries along that axis (mono -> stereo).
    The result is then cast to the target device and dtype.

    Args:
        audio: Waveform tensor of rank 1, 2 or 3.
        debug: When true, print the resulting shape and device.
        device: Target device. Defaults to the notebook-level ``audio_device``.
        dtype: Target dtype. Defaults to the notebook-level ``audio_dtype``.

    Returns:
        The reshaped tensor on ``device`` with dtype ``dtype``.

    Raises:
        ValueError: if ``audio`` is not a rank 1, 2 or 3 tensor.
    """
    # Only fall back to the notebook globals when the caller did not choose,
    # so the function is usable (and testable) without that hidden state.
    if device is None:
        device = audio_device
    if dtype is None:
        dtype = audio_dtype

    rank = audio.dim()
    if rank == 1:
        audio = audio.unsqueeze(0).unsqueeze(0)
    elif rank == 2:
        audio = audio.unsqueeze(0)
    elif rank != 3:
        # Fail here with a clear message instead of deep inside the encoder.
        raise ValueError(f"Expected a 1-D, 2-D or 3-D audio tensor, got {rank} dimensions")

    if audio.shape[1] == 1:
        audio = audio.repeat(1, 2, 1)
    audio = audio.to(device=device, dtype=dtype)
    if debug:
        print(f"Shape: {audio.shape} \n Device: {audio.device}")
    return audio

@torch.no_grad()
def encode_audio_latent(path_to_audio, autoencoder, target_sample_rate=None):
    """Load one audio file and encode it with ``autoencoder``.

    The file is loaded with ``torchaudio.load``, resampled when its sample
    rate differs from the target rate, reshaped and cast by
    ``clean_audio_dim``, and passed to ``autoencoder.encode``. Runs with
    gradients disabled.

    Args:
        path_to_audio: Path of the audio file to load.
        autoencoder: Object exposing ``encode(audio)``.
        target_sample_rate: Sample rate to resample to before encoding. When
            ``None``, ``autoencoder.sample_rate`` is used if that attribute
            exists; otherwise the audio is encoded at its native rate.

    Returns:
        Whatever ``autoencoder.encode`` returns for the prepared audio.
    """
    audio, audio_sr = torchaudio.load(path_to_audio)

    if target_sample_rate is None:
        # NOTE(review): assumes the autoencoder exposes its expected rate as
        # `.sample_rate` - confirm. If it does not, nothing is resampled,
        # which is what this function did before.
        target_sample_rate = getattr(autoencoder, "sample_rate", None)

    # Previously the file's own sample rate was read and then ignored, so a
    # file recorded at a different rate was encoded as-is.
    if target_sample_rate is not None and audio_sr != int(target_sample_rate):
        audio = torchaudio.functional.resample(audio, audio_sr, int(target_sample_rate))

    audio = clean_audio_dim(audio)
    latents = autoencoder.encode(audio)
    return latents

def _encode_and_save(path_to_audio, autoencoder, save_path):
    """Encode one file and write its latents (moved to CPU) to ``save_path``."""
    latents = encode_audio_latent(path_to_audio, autoencoder)
    torch.save(latents.cpu(), save_path)


def encode_audio_latents(list_of_audio_paths, autoencoder, save_to='data/audio_latents'):
    """Encode many audio files and save one ``.pt`` latent file per input.

    Each file is encoded and saved as ``<save_to>/<file name without
    extension>.pt``. When that fails, the file is passed to the notebook-level
    Audacity pipeline ``ap`` and the cleaned copy is encoded instead. Messages
    for every failure are appended to ``<save_to>/error.log``, followed by the
    list of files that could not be encoded at all.

    Args:
        list_of_audio_paths: Iterable of audio file paths.
        autoencoder: Autoencoder forwarded to ``encode_audio_latent``.
        save_to: Output directory; created if it does not exist.
    """
    save_dir = os.path.abspath(save_to)
    os.makedirs(save_dir, exist_ok=True)

    errored_paths = []
    log_lines = []

    def _report(message):
        # Mirror every failure message to stdout and to the on-disk log.
        print(message)
        log_lines.append(message)

    for path_to_audio in tqdm(list_of_audio_paths):
        audio_name = os.path.splitext(os.path.basename(path_to_audio))[0]
        # NOTE: the output name drops the directory and the extension, so two
        # inputs with the same stem write to the same .pt file (last one wins).
        save_path = os.path.join(save_dir, f"{audio_name}.pt")

        try:
            _encode_and_save(path_to_audio, autoencoder, save_path)
            continue
        except Exception as e:
            # Include the original error; it was previously dropped from the log.
            _report(f"Ran into error on file {audio_name}: {e}\nAttempting to fix with audacity:\n")

        # An Audacity failure (e.g. it is not running) must not abort the
        # whole batch; treat it like "no cleaned file produced".
        try:
            new_path = ap.clean_audio_via_audacity(path_to_audio)
        except Exception as e:
            _report(f"Audacity clean-up failed for {audio_name}: {e}")
            new_path = None

        if not new_path:
            errored_paths.append(path_to_audio)
            continue

        try:
            _encode_and_save(new_path, autoencoder, save_path)
            print("Successfully fixed with audacity processed file:", new_path)
        except Exception as e:
            _report(f"Failed to fix with new audacity processed file: {e}")
            errored_paths.append(path_to_audio)

    if log_lines:
        with open(os.path.join(save_dir, 'error.log'), 'a', encoding='utf-8') as f:
            f.write(''.join(f'{line}\n' for line in log_lines))
            if errored_paths:
                f.write('\nFailed files:\n')
                f.write('\n'.join(str(path) for path in errored_paths))
            f.write('\n\n')

def get_audio_file_paths(folder, audio_data_path = "data/BDCT-0/", subfolders=('Audio Files', 'Bounced Files')):
    """List the files stored in a project folder's audio sub-directories.

    Args:
        folder: Name of the project folder inside ``audio_data_path``.
        audio_data_path: Root directory of the dataset; made absolute.
        subfolders: Names of the sub-directories of ``folder`` to scan, in
            the order their files should appear in the result.

    Returns:
        Absolute paths of the regular, non-hidden files found, grouped by
        sub-directory and sorted by file name within each group.

    Raises:
        FileNotFoundError: if one of the sub-directories does not exist.
    """
    base_path = os.path.join(os.path.abspath(audio_data_path), folder)

    audio_file_paths = []
    for subfolder in subfolders:
        subfolder_path = os.path.join(base_path, subfolder)
        # os.listdir order is arbitrary; sort for a reproducible result.
        for file in sorted(os.listdir(subfolder_path)):
            candidate = os.path.join(subfolder_path, file)
            # Skip nested directories and hidden files such as .DS_Store,
            # which are not audio and would only trigger the failure path.
            if file.startswith('.') or not os.path.isfile(candidate):
                continue
            audio_file_paths.append(candidate)

    return audio_file_paths

In [None]:
# Project folder to process; also used as the output sub-folder name for the latents.
focused_directory = 'UNVWTU'

In [None]:
# Collect the file paths from this folder's 'Audio Files' and 'Bounced Files' directories.
audio_file_paths = get_audio_file_paths(focused_directory)

In [None]:
# Encode each file and save its latents under data/audio_latents/<focused_directory>;
# failures are appended to error.log in that same directory.
encode_audio_latents(audio_file_paths, autoencoder, save_to=f'data/audio_latents/{focused_directory}')

##### Playground

In [None]:
# Download model | Stable Audio Open (normal)
# `https://huggingface.co/stabilityai/stable-audio-open-1.0`
# NOTE: this rebinds `model`, `model_config`, `sample_rate` and `sample_size`,
# replacing the values set when the small model was loaded earlier.
model, model_config = get_pretrained_model("stabilityai/stable-audio-open-1.0")
sample_rate = model_config["sample_rate"]
sample_size = model_config["sample_size"]

# Inference only: switch to eval mode as well, matching the first model-loading cell.
model = model.to(device).eval()

In [None]:
# Set up text and timing conditioning
conditioning = [{
    "prompt": "60 BPM jazz saxophone solo",  # This prompt is quite bad on small, but small does work
    # "seconds_start": 0,
    "seconds_total": 11
}]

# NOTE(review): the sampler below is the one marked "for small", while the cell
# above loads stable-audio-open-1.0 - confirm which model this cell is meant for.
# Generate stereo audio
output = generate_diffusion_cond(
    model,
    steps=8,
    cfg_scale=1.0,
    conditioning=conditioning,
    sample_size=sample_size,
    # sigma_min=0.3,
    # sigma_max=500,
    # sampler_type="dpmpp-3m-sde",  # Use this for normal open
    sampler_type="pingpong",  # Use this for small
    device=device
)


# Rearrange audio batch to a single sequence
output = rearrange(output, "b d n -> d (b n)")

# Peak normalize, clip and convert to int16 (nothing is written to disk here).
output = output.to(torch.float32)
# Guard the divisor so an all-zero (silent) result does not turn into NaN.
peak = output.abs().max().clamp(min=1e-8)
output = output.div(peak).clamp(-1, 1).mul(32767).to(torch.int16).cpu()

In [None]:
# output: (channels, samples) int16 tensor on CPU, already peak-normalized in the generation cell
from IPython.display import Audio, display
display(Audio(output.numpy(), rate=sample_rate))

In [None]:
# Reuse the helper defined earlier instead of repeating the `_modules` lookup.
ae = get_autoencoder(model)

In [None]:
# Display the autoencoder's module structure.
ae

In [None]:
# Keep a handle on the autoencoder's encoder sub-object for inspection.
encoder = ae.encoder

In [None]:
# Display the autoencoder again (same output as the earlier `ae` cell).
ae

In [None]:
# Debug inspection: dump the encoder's instance attributes.
encoder.__dict__

In [None]:
# Load a reference clip: keep only the waveform ([0]) and add a leading axis with unsqueeze(0).
# The sample rate returned by torchaudio.load is discarded here.
normal_audio = torchaudio.load('../normal_test.wav')[0].unsqueeze(0)

In [None]:
# Tensor.to() returns a new tensor rather than moving in place, so rebind the result.
normal_audio = normal_audio.to('cpu')
# Module.to() moves the module's parameters in place; as the last expression
# it also displays the module.
ae.to('cpu')

In [None]:
# Encode the clip; with return_info=True a second value is returned alongside the latents.
latents, latent_info = ae.encode(normal_audio, return_info=True)

In [None]:
# Inspect the shape of the encoded latents.
latents.shape

In [None]:
# Inspect the extra info returned by encode(..., return_info=True).
latent_info

In [None]:
from aeiou import viz

In [None]:
# Visualize the latents with aeiou's tokens_spectrogram_image helper.
viz.tokens_spectrogram_image(latents)

In [None]:
# normal_audio carries the leading axis added at load time; squeeze(0) drops it for playback.
# The playback rate is hard-coded to 44100 because the file's own rate was not kept.
from IPython.display import Audio, display
display(Audio(normal_audio.squeeze(0).numpy(), rate=44100))

In [None]:
# 41000 was a typo; use 44100, the same rate the clip is played back at in the previous cell.
viz.playable_spectrogram(normal_audio.squeeze(0), sample_rate=44100, output_type="live")