# Welcome to HW4!

# Setup

First, let's load our model and initialize the global variables SAMPLE_RATE (the number of audio samples per second, in this case 44100), SAMPLE_SIZE (the *number* of audio samples we generate with the model, approximately 47.55*44100/8, i.e. about 5.9 seconds of audio), and SEED (controls randomness; DO NOT CHANGE).

In [None]:
import torch
import torchaudio
from einops import rearrange
from stable_audio_tools import get_pretrained_model
import IPython.display as ipd
from tqdm.auto import trange, tqdm
from stable_audio_tools.inference.generation import generate_diffusion_cond_and_sampler_setup, generate_diffusion_cond_decode
import IPython.display as ipd
import gc

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): disabled default-tensor-type override; delete it if it is no longer needed.
# torch.set_default_tensor_type(torch.Tensor)


# Download the Stable Audio Open 1.0 model and its config.
model, model_config = get_pretrained_model("stabilityai/stable-audio-open-1.0")
# Audio samples per second used by the model.
SAMPLE_RATE = model_config["sample_rate"]
# Generate 1/8 of the model's configured sample_size (per the notes above,
# roughly 47.55 s / 8, i.e. about 5.9 s of audio).
SAMPLE_SIZE = model_config["sample_size"] // 8
# Fixed seed so outputs are comparable with the provided reference files; DO NOT CHANGE.
SEED = 456

# Seed the CPU and CUDA RNGs and force deterministic cuDNN kernels.
# NOTE(review): seeding happens after the model is built; keep this order
# unless you confirm that model construction does not consume RNG state.
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Cast all weights to fp16 (saved latents are also cast to half when they
# are loaded later in this notebook), then move the model to `device`.
model = model.half()

model = model.to(device)

In [None]:
# If you are using Colab, uncomment the following lines
# from google.colab import drive
# drive.mount('/content/drive')
# cd /content/drive/MyDrive/[path to your folder]
# pip install -e .
# pip install numpy==1.26.4
# pip install protobuf==3.20.1

# Q1 Simple Sampler

Here you should implement the to_d and simple_sample functions:

In [None]:

def to_d(x, sigma, denoised):
    """Convert a denoiser output into the ODE derivative dx/dsigma.

    For the probability-flow ODE used by Karras-style (k-diffusion) samplers,
    the derivative at noise level ``sigma`` is ``(x - denoised) / sigma``.

    Args:
        x: current noisy sample.
        sigma: noise level of ``x``. This can be a Python number, a 0-d tensor,
            or a tensor with one entry per batch element (shape ``[batch]``).
        denoised: the model's estimate of the clean sample at ``sigma``.

    Returns:
        A tensor with the same shape as ``x``.
    """
    sigma = torch.as_tensor(sigma, device=x.device)
    # Append trailing singleton dims so a per-batch sigma broadcasts over x.
    sigma = sigma.reshape(sigma.shape + (1,) * (x.ndim - sigma.ndim))
    return (x - denoised) / sigma

@torch.no_grad()
def simple_sample(model, x, sigmas, extra_args=None):
    """Integrate the diffusion ODE with first-order (Euler) steps.

    Args:
        model: denoiser, called as ``model(x, sigma_batch, **extra_args)``;
            its output is treated as the estimate of the clean sample.
        x: starting sample at noise level ``sigmas[0]``.
        sigmas: noise schedule from highest to lowest. One Euler step is taken
            between each consecutive pair.
        extra_args: extra keyword arguments forwarded to ``model`` on every
            call (presumably the conditioning built by the setup helper).

    Returns:
        The sample at noise level ``sigmas[-1]``.
    """
    extra_args = {} if extra_args is None else extra_args
    # Used to expand a scalar sigma into one value per batch element.
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1):
        sigma, sigma_next = sigmas[i], sigmas[i + 1]
        denoised = model(x, sigma * s_in, **extra_args)
        # ODE derivative dx/dsigma = (x - denoised) / sigma.
        d = (x - denoised) / sigma
        x = x + d * (sigma_next - sigma)
    del extra_args
    torch.cuda.empty_cache()
    return x

In [None]:
def generate(prompt="128 BPM electronic drum loop", steps=50, cfg_scale=7, return_latents=False, x_start=None):
    """Generate one clip for ``prompt`` with ``simple_sample``.

    Args:
        prompt: text prompt used for conditioning.
        steps: number of sampler steps (more usually means better quality).
        cfg_scale: classifier-free guidance scale. Higher values follow the
            text more closely but give less diversity.
        return_latents: if True, return the sampled latents instead of audio.
        x_start: optional starting sample that replaces the setup's noise.

    Returns:
        Decoded audio on the CPU, or the raw latents if ``return_latents``.
    """
    # Text + timing conditioning for a single clip.
    conditioning = [{"prompt": prompt, "seconds_start": 0, "seconds_total": 5}]

    # SAMPLE_SIZE and SEED are fixed by the assignment; DON'T CHANGE them.
    denoiser, start_noise, sigmas, extra_args = generate_diffusion_cond_and_sampler_setup(
        model,
        steps=steps,
        cfg_scale=cfg_scale,
        conditioning=conditioning,
        sample_size=SAMPLE_SIZE,
        device=device,
        seed=SEED,
    )
    if x_start is not None:
        start_noise = x_start

    latents = simple_sample(denoiser, start_noise, sigmas, extra_args=extra_args)

    # Drop the sampler inputs and free cached memory before decoding.
    del start_noise, sigmas, extra_args
    torch.cuda.empty_cache()
    gc.collect()

    if return_latents:
        return latents
    return generate_diffusion_cond_decode(model, latents).cpu()



In [None]:
# To test your function, we provide some example latents to compare against;
# the MSE between your latents and the reference latents should be low.
for ix, prompt in enumerate(["lo-fi jazz piano in a rainy cafe", "deep ambient wash with ocean sounds"]):
    # Reference latents. map_location="cpu" lets files saved from a GPU
    # session load on CPU-only machines; then move them to `device`.
    ref = torch.load(f"testing_files/q1_{ix}.pt", map_location="cpu").to(device)

    # Shared starting sample, passed as x_start so results are comparable.
    x_T = torch.load("x_T.pt", map_location="cpu").to(device)

    latents = generate(prompt=prompt, steps=50, cfg_scale=7, return_latents=True, x_start=x_T)
    # Compare against the reference latents.
    print(f"Latent MSE: {torch.nn.functional.mse_loss(ref, latents)}")


In [None]:
# for those running with a GPU, you can also generate audio samples
# this is impractical on CPU, but you can try it if you want (we DO NOT recommend it)
# for ix, prompt in enumerate(["lo-fi jazz piano in a rainy cafe", "deep ambient wash with ocean sounds"]):
#     # generate audio for this prompt
#     audio = generate(prompt=prompt, steps=50, cfg_scale=7, return_latents=False)
#     # play audio
#     ipd.display(ipd.Audio(audio.cpu().numpy()[0], rate=SAMPLE_RATE))

# Q2 - Inpainting Mask

In [None]:
# LOAD AND ENCODE REFERENCE AUDIO
def load_and_encode_audio(path, model):
    """Load an audio file and encode it into the model's latent space.

    The audio is:
    1. resampled to ``SAMPLE_RATE``;
    2. peak-normalized (silent clips are left as-is);
    3. tiled if shorter than ``SAMPLE_SIZE`` frames, then trimmed to exactly
       ``SAMPLE_SIZE``;
    4. encoded with ``model.pretransform.encode`` as a batch of one clip.

    Args:
        path: anything ``torchaudio.load`` accepts.
        model: model whose ``pretransform`` performs the encoding.

    Returns:
        The encoded latent returned by ``model.pretransform.encode``.

    Raises:
        ValueError: if the file contains no audio samples.

    NOTE(review): mono files are not up-mixed; confirm the channel count the
    encoder expects.
    """
    audio, sr = torchaudio.load(path)  # (channels, frames)
    audio = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(audio)

    num_frames = audio.shape[1]
    if num_frames == 0:
        # Tiling an empty clip would never reach SAMPLE_SIZE.
        raise ValueError(f"no audio samples in {path!r}")

    # Peak-normalize, but avoid 0/0 -> NaN on an all-silent clip.
    peak = audio.abs().max()
    if peak > 0:
        audio = audio / peak

    # Tile shorter clips (ceil division), then trim to exactly SAMPLE_SIZE.
    if num_frames < SAMPLE_SIZE:
        audio = audio.repeat(1, (SAMPLE_SIZE + num_frames - 1) // num_frames)
    audio = audio[:, :SAMPLE_SIZE][None]

    # Match the encoder's parameter dtype. The notebook casts the model to
    # fp16, and feeding it fp32 audio would cause a dtype mismatch.
    param = next(model.pretransform.parameters(), None)
    dtype = audio.dtype if param is None else param.dtype
    audio = audio.to(device=device, dtype=dtype)

    # Inference only: don't build an autograd graph through the encoder.
    with torch.no_grad():
        return model.pretransform.encode(audio)


def load_encoded_audio(path):
    """Load a saved latent tensor, cast it to fp16, and move it to ``device``.

    The half-precision cast matches the model, which the setup cell converts
    with ``model.half()``.
    """
    # map_location="cpu" lets tensors saved from a GPU session load on a
    # CPU-only machine; .to(device) then places the tensor where it is used.
    encoded_latent = torch.load(path, map_location="cpu")
    return encoded_latent.half().to(device)
    



In [None]:
def generate_inpainting_mask(reference, mask_start_s, mask_end_s, *, sample_rate=None, sample_size=None):
    """Build a latent-space mask marking the time span to regenerate.

    Convention (chosen here; confirm against the assignment spec): the mask
    has the same shape, dtype and device as ``reference``. It is 1 on the
    latent frames inside ``[mask_start_s, mask_end_s)`` seconds (the region
    the sampler regenerates) and 0 elsewhere (the region kept from
    ``reference``).

    Seconds are converted to latent frames on the assumption that the last
    axis of ``reference`` spans ``sample_size`` audio samples at
    ``sample_rate``. Times outside the clip are clamped, so a span that lies
    entirely outside it yields an all-zero mask.

    Args:
        reference: encoded latent; its last axis is treated as time.
        mask_start_s: start of the regenerated span, in seconds.
        mask_end_s: end of the regenerated span (exclusive), in seconds.
        sample_rate: audio samples per second; defaults to ``SAMPLE_RATE``.
        sample_size: audio samples covered by ``reference``; defaults to
            ``SAMPLE_SIZE``.

    Returns:
        A tensor shaped like ``reference`` containing 0s and 1s.
    """
    sample_rate = SAMPLE_RATE if sample_rate is None else sample_rate
    sample_size = SAMPLE_SIZE if sample_size is None else sample_size

    num_frames = reference.shape[-1]
    frames_per_second = num_frames * sample_rate / sample_size

    def to_frame(seconds):
        # Nearest latent frame, clamped to the valid slice range [0, num_frames].
        return min(max(round(seconds * frames_per_second), 0), num_frames)

    mask = torch.zeros_like(reference)
    mask[..., to_frame(mask_start_s):to_frame(mask_end_s)] = 1
    return mask

In [None]:
# To test your function, we provide some example masks to compare against;
# the MSE between your masks and the reference masks should be low.
for ix in range(2):
    # Load each file once per clip rather than once per mask range.
    # map_location="cpu" lets GPU-saved files load on CPU-only machines.
    ref_masks = torch.load(f"testing_files/q2_{ix}.pt", map_location="cpu")
    # Encoded reference audio that the masks are built for.
    reference = load_encoded_audio(f"testing_files/q1_{ix}.pt")
    for midx, mask_range in enumerate([(0,4), (1,2), (1.5,3), (2,4), (1,4), (0,5)]):
        ref = ref_masks[midx].to(device)
        mask = generate_inpainting_mask(reference, *mask_range)
        # Compare against the reference mask.
        print(f"Latent MSE: {torch.nn.functional.mse_loss(ref, mask)}")



# Q3  - Inpainting

In [None]:
@torch.no_grad()
def simple_sample_inpaint(model, x, sigmas, reference, mask, extra_args=None):
    """Run Euler sampling that regenerates only the masked part of ``reference``.

    Convention (chosen here; confirm against the assignment spec): ``mask``
    is 1 where new content is generated and 0 where ``reference`` is kept.
    After every Euler step, the kept region is overwritten with ``reference``
    noised to the next sigma. The generated region is therefore denoised in
    context, and if the schedule ends at sigma 0 the kept region ends exactly
    equal to ``reference``.

    The noise added to ``reference`` is the direction of the starting sample,
    ``x / sigmas[0]``, reused at every step, so no extra random numbers are
    drawn. This assumes ``x`` starts as noise scaled by ``sigmas[0]``; TODO
    confirm.

    Args:
        model: denoiser, called as ``model(x, sigma_batch, **extra_args)``.
        x: starting sample at noise level ``sigmas[0]``.
        sigmas: noise schedule from highest to lowest.
        reference: clean latent to preserve; must broadcast against ``x``.
        mask: same shape as ``reference`` (see the convention above).
        extra_args: keyword arguments forwarded to ``model`` on every call.

    Returns:
        The inpainted sample at noise level ``sigmas[-1]``.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    # Match x's device and dtype so the blend below keeps x's dtype.
    reference = reference.to(device=x.device, dtype=x.dtype)
    mask = mask.to(device=x.device, dtype=x.dtype)
    noise = x / sigmas[0] if len(sigmas) > 1 else None
    for i in trange(len(sigmas) - 1):
        sigma, sigma_next = sigmas[i], sigmas[i + 1]
        denoised = model(x, sigma * s_in, **extra_args)
        # Euler step with the ODE derivative (x - denoised) / sigma.
        x = x + (x - denoised) / sigma * (sigma_next - sigma)
        # Pin the kept region to the reference at the new noise level.
        known = reference + noise * sigma_next
        x = mask * x + (1 - mask) * known
    del extra_args
    torch.cuda.empty_cache()
    return x


In [None]:
def inpaint(prompt="128 BPM house drum loop", steps=50, cfg_scale=7, reference=None, mask_start_s=20, mask_end_s=30, return_latents=False, x_start=None):
    """Regenerate the ``[mask_start_s, mask_end_s]`` span of ``reference``.

    The mask is built with ``generate_inpainting_mask``, and sampling is done
    by ``simple_sample_inpaint``.

    NOTE(review): the default span (20-30 s) lies past the end of the
    ~5.9 s clip this notebook generates. The test cells always pass explicit
    values, and ``reference`` must be provided.

    Returns:
        Decoded audio on the CPU, or the raw latents if ``return_latents``.
    """
    conditioning = [{"prompt": prompt, "seconds_start": 0, "seconds_total": 5}]
    inpaint_mask = generate_inpainting_mask(reference, mask_start_s, mask_end_s)

    denoiser, start_noise, sigmas, extra_args = generate_diffusion_cond_and_sampler_setup(
        model,
        steps=steps,
        cfg_scale=cfg_scale,
        conditioning=conditioning,
        sample_size=SAMPLE_SIZE,
        device=device,
        seed=SEED,
    )
    if x_start is not None:
        start_noise = x_start

    latents = simple_sample_inpaint(denoiser, start_noise, sigmas, reference, inpaint_mask, extra_args=extra_args)

    # Release the sampler inputs and cached memory before decoding.
    del start_noise, sigmas, extra_args
    torch.cuda.empty_cache()
    gc.collect()

    if return_latents:
        return latents
    return generate_diffusion_cond_decode(model, latents).cpu()



In [None]:
# To test your function, we provide some example latents to compare against;
# the MSE between your latents and the reference latents should be low.

for ix, prompt in enumerate(["lo-fi jazz piano in a rainy cafe", "deep ambient wash with ocean sounds"]):
    # Reference latents; map_location="cpu" lets GPU-saved files load anywhere.
    ref = torch.load(f"testing_files/q3_{ix}.pt", map_location="cpu").to(device)
    # Encoded reference audio to inpaint.
    reference = load_encoded_audio(f"testing_files/q1_{ix}.pt")

    # Shared starting sample, passed as x_start so results are comparable.
    x_T = torch.load("x_T.pt", map_location="cpu").to(device)

    # inpaint() builds the 0-3 s mask itself from mask_start_s / mask_end_s.
    latents = inpaint(prompt=prompt, steps=50, cfg_scale=7, reference=reference, mask_start_s=0, mask_end_s=3, return_latents=True, x_start=x_T)
    # Compare against the reference latents.
    print(f"Latent MSE: {torch.nn.functional.mse_loss(ref, latents)}")


In [None]:
# for those running with a GPU, you can also generate audio samples
# this is impractical on CPU, but you can try it if you want (we DO NOT recommend it)
# for ix, prompt in enumerate(["lo-fi jazz piano in a rainy cafe", "deep ambient wash with ocean sounds"]):
#     # load reference from testing_files
#     reference = load_encoded_audio(f"testing_files/q1_{ix}.pt")
#     # generate mask
#     mask = generate_inpainting_mask(reference, 0, 3)
#     # generate inpainting
#     audio = inpaint(prompt=prompt, steps=50, cfg_scale=7, reference=reference, mask_start_s=0, mask_end_s=3, return_latents=False)
#     # play audio
#     ipd.display(ipd.Audio(audio.cpu().numpy()[0], rate=SAMPLE_RATE))


# Q4 Painting with Starting and Stopping Times

In [None]:
@torch.no_grad()
def simple_sample_variable_inpaint(model, x, sigmas, reference, mask, extra_args=None, paint_start=None, paint_end=None):
    """Run Euler inpainting whose reference constraint is active only on some steps.

    On step ``i``, with ``paint_start <= i < paint_end``, the kept region
    (``mask == 0``) is overwritten after the Euler step with ``reference``
    noised to the next sigma. On all other steps the model runs freely over
    the whole sample. With the defaults (all steps) this is plain
    replacement inpainting.

    Convention (chosen here; confirm against the assignment spec): ``mask``
    is 1 where new content is generated and 0 where ``reference`` is kept.
    The noise added to ``reference`` is the direction of the starting sample,
    ``x / sigmas[0]``, reused at every step.

    Args:
        model: denoiser, called as ``model(x, sigma_batch, **extra_args)``.
        x: starting sample at noise level ``sigmas[0]``.
        sigmas: noise schedule from highest to lowest.
        reference: clean latent to preserve; must broadcast against ``x``.
        mask: same shape as ``reference`` (see the convention above).
        extra_args: keyword arguments forwarded to ``model`` on every call.
        paint_start: first step index with the constraint (default 0).
        paint_end: step index where the constraint stops, exclusive
            (default ``len(sigmas) - 1``, i.e. every step).

    Returns:
        The sample at noise level ``sigmas[-1]``.
    """
    if paint_start is None:
        paint_start = 0
    if paint_end is None:
        paint_end = len(sigmas) - 1
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    # Match x's device and dtype so the blend below keeps x's dtype.
    reference = reference.to(device=x.device, dtype=x.dtype)
    mask = mask.to(device=x.device, dtype=x.dtype)
    noise = x / sigmas[0] if len(sigmas) > 1 else None
    for i in trange(len(sigmas) - 1):
        sigma, sigma_next = sigmas[i], sigmas[i + 1]
        denoised = model(x, sigma * s_in, **extra_args)
        # Euler step with the ODE derivative (x - denoised) / sigma.
        x = x + (x - denoised) / sigma * (sigma_next - sigma)
        if paint_start <= i < paint_end:
            # Pin the kept region to the reference at the new noise level.
            known = reference + noise * sigma_next
            x = mask * x + (1 - mask) * known
    del extra_args
    torch.cuda.empty_cache()
    return x


In [None]:
def variable_inpaint(prompt="128 BPM house drum loop", steps=50, cfg_scale=7, reference=None, mask_start_s=20, mask_end_s=30, paint_start=None, paint_end=None, return_latents=False, x_start=None):
    """Inpaint ``reference`` with the constraint active only on some sampler steps.

    This builds the mask with ``generate_inpainting_mask`` and samples with
    ``simple_sample_variable_inpaint``. ``paint_start`` and ``paint_end`` are
    forwarded unchanged to that sampler.

    NOTE(review): the default span (20-30 s) lies past the end of the
    ~5.9 s clip this notebook generates. The test cells always pass explicit
    values, and ``reference`` must be provided.

    Returns:
        Decoded audio on the CPU, or the raw latents if ``return_latents``.
    """
    conditioning = [{"prompt": prompt, "seconds_start": 0, "seconds_total": 5}]
    inpaint_mask = generate_inpainting_mask(reference, mask_start_s, mask_end_s)

    denoiser, start_noise, sigmas, extra_args = generate_diffusion_cond_and_sampler_setup(
        model,
        steps=steps,
        cfg_scale=cfg_scale,
        conditioning=conditioning,
        sample_size=SAMPLE_SIZE,
        device=device,
        seed=SEED,
    )
    if x_start is not None:
        start_noise = x_start

    latents = simple_sample_variable_inpaint(
        denoiser,
        start_noise,
        sigmas,
        reference,
        inpaint_mask,
        extra_args=extra_args,
        paint_start=paint_start,
        paint_end=paint_end,
    )

    # Release the sampler inputs and cached memory before decoding.
    del start_noise, sigmas, extra_args
    torch.cuda.empty_cache()
    gc.collect()

    if return_latents:
        return latents
    return generate_diffusion_cond_decode(model, latents).cpu()



In [None]:
# To test your function, we provide some example latents to compare against;
# the MSE between your latents and the reference latents should be low.

for ix, prompt in enumerate(["lo-fi jazz piano in a rainy cafe", "deep ambient wash with ocean sounds"]):
    # Reference latents; map_location="cpu" lets GPU-saved files load anywhere.
    ref = torch.load(f"testing_files/q4_{ix}.pt", map_location="cpu").to(device)

    # Encoded reference audio to inpaint.
    reference = load_encoded_audio(f"testing_files/q1_{ix}.pt")

    # Per-clip mask span (seconds) and painting window (sampler steps).
    # variable_inpaint() builds the mask itself, so the chosen span must be
    # passed to it. (Previously clip 1 computed a 2-5 s mask but passed 0-3 s.)
    if ix == 0:
        mask_start_s, mask_end_s = 0, 3
        paint_start = 0
        paint_end = 20
    else:
        mask_start_s, mask_end_s = 2, 5
        paint_start = 15
        paint_end = 45

    # Shared starting sample, passed as x_start so results are comparable.
    x_T = torch.load("x_T.pt", map_location="cpu").to(device)

    latents = variable_inpaint(prompt=prompt, steps=50, cfg_scale=7, reference=reference, mask_start_s=mask_start_s, mask_end_s=mask_end_s, paint_start=paint_start, paint_end=paint_end, return_latents=True, x_start=x_T)
    # Compare against the reference latents.
    print(f"Latent MSE: {torch.nn.functional.mse_loss(ref, latents)}")



In [None]:
# for those running with a GPU, you can also generate audio samples
# this is impractical on CPU, but you can try it if you want (we DO NOT recommend it)
# for ix, prompt in enumerate(["lo-fi jazz piano in a rainy cafe", "deep ambient wash with ocean sounds"]):
#     # load reference audio
#     reference = load_encoded_audio(f"testing_files/q1_{ix}.pt")
#     # choose the mask span (seconds) and painting window (sampler steps)
#     if ix == 0:
#         mask_start_s, mask_end_s = 0, 3
#         paint_start = 0
#         paint_end = 20
#     else:
#         mask_start_s, mask_end_s = 2, 5
#         paint_start = 15
#         paint_end = 45
#     # generate inpainting (variable_inpaint builds the mask itself)
#     audio = variable_inpaint(prompt=prompt, steps=50, cfg_scale=7, reference=reference, mask_start_s=mask_start_s, mask_end_s=mask_end_s, paint_start=paint_start, paint_end=paint_end, return_latents=False)
#     # play audio
#     ipd.display(ipd.Audio(audio.cpu().numpy()[0], rate=SAMPLE_RATE))

# Q5 Style Transfer

In [None]:
@torch.no_grad()
def simple_sample_style_transfer(model, sigmas, reference, extra_args=None, transfer_strength=0):
    """Do SDEdit-style transfer: noise ``reference`` part-way, then denoise it.

    Convention (chosen here; confirm against the assignment spec):
    ``transfer_strength`` in [0, 1] is the fraction of the sigma schedule to
    skip.
    - 0 starts at ``sigmas[0]``, so the reference is buried under the
      largest noise level, which is close to ordinary generation.
    - Larger values start at a lower noise level and keep more of the
      reference's structure.
    - 1 skips every step.

    Fresh Gaussian noise is drawn with ``torch.randn_like(reference)``.

    Args:
        model: denoiser, called as ``model(x, sigma_batch, **extra_args)``.
        sigmas: noise schedule from highest to lowest.
        reference: clean latent whose structure is carried over.
        extra_args: keyword arguments forwarded to ``model`` on every call.
        transfer_strength: fraction of the schedule skipped (see above).

    Returns:
        The sample at noise level ``sigmas[-1]`` (``reference`` itself if
        ``sigmas`` is empty).

    Raises:
        ValueError: if ``transfer_strength`` is outside [0, 1].
    """
    if not 0 <= transfer_strength <= 1:
        raise ValueError(f"transfer_strength must be in [0, 1], got {transfer_strength}")
    if len(sigmas) == 0:
        return reference
    extra_args = {} if extra_args is None else extra_args

    num_steps = len(sigmas) - 1
    start = int(transfer_strength * num_steps)
    # Start from the reference noised to the chosen sigma.
    x = reference + torch.randn_like(reference) * sigmas[start]
    s_in = x.new_ones([x.shape[0]])
    for i in trange(start, num_steps):
        sigma, sigma_next = sigmas[i], sigmas[i + 1]
        denoised = model(x, sigma * s_in, **extra_args)
        # Euler step with the ODE derivative (x - denoised) / sigma.
        x = x + (x - denoised) / sigma * (sigma_next - sigma)
    del extra_args
    torch.cuda.empty_cache()
    return x


In [None]:
def style_transfer(prompt="128 BPM house drum loop", steps=50, cfg_scale=7, reference=None, transfer_strength=0, return_latents=False, x_start=None):
    """Re-render ``reference`` toward ``prompt`` with ``simple_sample_style_transfer``.

    NOTE(review): ``x_start`` is accepted for symmetry with the other
    wrappers, but it is never passed to the sampler, which receives only
    ``sigmas`` and ``reference``.

    Returns:
        Decoded audio on the CPU, or the raw latents if ``return_latents``.
    """
    conditioning = [{"prompt": prompt, "seconds_start": 0, "seconds_total": 5}]

    denoiser, start_noise, sigmas, extra_args = generate_diffusion_cond_and_sampler_setup(
        model,
        steps=steps,
        cfg_scale=cfg_scale,
        conditioning=conditioning,
        sample_size=SAMPLE_SIZE,
        device=device,
        seed=SEED,
    )
    if x_start is not None:
        start_noise = x_start

    latents = simple_sample_style_transfer(
        denoiser,
        sigmas,
        reference,
        extra_args=extra_args,
        transfer_strength=transfer_strength,
    )

    # Release the sampler inputs and cached memory before decoding.
    del start_noise, sigmas, extra_args
    torch.cuda.empty_cache()
    gc.collect()

    if return_latents:
        return latents
    return generate_diffusion_cond_decode(model, latents).cpu()
    

In [None]:
# To test your function, we provide some example latents to compare against;
# the MSE between your latents and the reference latents should be low.
# The prompt order is reversed relative to Q1, so each prompt is applied to
# the reference latents generated from the other prompt.
for ix, prompt in enumerate(["deep ambient wash with ocean sounds", "lo-fi jazz piano in a rainy cafe"]):
    # Reference latents; map_location="cpu" lets GPU-saved files load anywhere.
    ref = torch.load(f"testing_files/q5_{ix}.pt", map_location="cpu").to(device)
    # Encoded reference audio whose structure is transferred.
    reference = load_encoded_audio(f"testing_files/q1_{ix}.pt")
    transfer_strength = 0.2 if ix == 0 else 0.5

    # Shared starting sample, passed as x_start for parity with other cells.
    x_T = torch.load("x_T.pt", map_location="cpu").to(device)

    latents = style_transfer(prompt=prompt, steps=50, cfg_scale=7, reference=reference, transfer_strength=transfer_strength, return_latents=True, x_start=x_T)
    # Compare against the reference latents.
    print(f"Latent MSE: {torch.nn.functional.mse_loss(ref, latents)}")

In [None]:
# for those running with a GPU, you can also generate audio samples
# this is impractical on CPU, but you can try it if you want (we DO NOT recommend it)
# for ix, prompt in enumerate(["lo-fi jazz piano in a rainy cafe", "deep ambient wash with ocean sounds"]):
#     # load reference from testing_files
#     reference = load_encoded_audio(f"testing_files/q1_{ix}.pt")
#     if ix == 0:
#         transfer_strength = 0.2
#     else:
#         transfer_strength = 0.5
#     # generate style transfer
#     audio = style_transfer(prompt=prompt, steps=50, cfg_scale=7, reference=reference, transfer_strength=transfer_strength, return_latents=False)
#     # play audio
#     ipd.display(ipd.Audio(audio.cpu().numpy()[0], rate=SAMPLE_RATE))