Skip to content

Harmonai-org/audio-diffusion-pytorch-fork

 
 

Repository files navigation

Unconditional audio generation using diffusion models, in PyTorch. The goal of this repository is to explore different architectures and diffusion models to generate audio (speech and music) directly from/to the waveform. Progress will be documented in the experiments section. You can use the audio-diffusion-pytorch-trainer to run your own experiments – please share your findings in the discussions page! Pretrained models can be found at archisound.

Install

pip install audio-diffusion-pytorch

PyPI - Python Version Downloads HuggingFace Open In Colab

Usage

Generation

from audio_diffusion_pytorch import AudioDiffusionModel

model = AudioDiffusionModel(in_channels=1)

# Train model with audio sources
x = torch.randn(2, 1, 2 ** 18) # [batch, in_channels, samples], 2**18 ≈ 12s of audio at a sample rate of 22050Hz
loss = model(x)
loss.backward() # Do this many times

# Sample 2 sources given start noise
noise = torch.randn(2, 1, 2 ** 18)
sampled = model.sample(
    noise=noise,
    num_steps=5 # Suggested range: 2-50
) # [2, 1, 2 ** 18]

Upsampling

from audio_diffusion_pytorch import AudioDiffusionUpsampler

upsampler = AudioDiffusionUpsampler(
    in_channels=1,
    factor=8,
)

# Train on high frequency data
x = torch.randn(2, 1, 2 ** 18)
loss = upsampler(x)
loss.backward()

# Given a downsampled source, sample the corresponding upsampled source
undersampled = torch.randn(1, 1, 2 ** 15)
upsampled = upsampler.sample(
    undersampled,
    num_steps=5
) # [1, 1, 2 ** 18]

Autoencoding

from audio_diffusion_pytorch import AudioDiffusionAE

autoencoder = AudioDiffusionAE(in_channels=1)

# Train on audio samples
x = torch.randn(2, 1, 2 ** 18)
loss = autoencoder(x)
loss.backward()

# Encode audio source into latent
x = torch.randn(2, 1, 2 ** 18)
latent = autoencoder.encode(x) # [2, 32, 128]

# Decode latent by diffusion sampling
decoded = autoencoder.decode(
    latent,
    num_steps=5
) # [2, 1, 2**18]

Conditional Generation

from audio_diffusion_pytorch import AudioDiffusionConditional

model = AudioDiffusionConditional(
    in_channels=1,
    embedding_max_length=64,
    embedding_features=768,
    embedding_mask_proba=0.1 # Conditional dropout of batch elements
)

# Train on pairs of audio and embedding data (e.g. from a transformer output)
x = torch.randn(2, 1, 2 ** 18)
embedding = torch.randn(2, 64, 768)
loss = model(x, embedding=embedding)
loss.backward()

# Given an embedding and start noise, sample a new source
embedding = torch.randn(2, 64, 768)
noise = torch.randn(2, 1, 2 ** 18)
sampled = model.sample(
    noise,
    embedding=embedding,
    embedding_scale=5.0, # Classifier-free guidance scale
    num_steps=5
) # [2, 1, 2 ** 18]

Text Conditional Generation

You can generate embeddings from text by using a pretrained frozen T5 transformer with T5Embedder, as follows (note that this requires `pip install transformers`):

from audio_diffusion_pytorch import T5Embedder

embedder = T5Embedder(model='t5-base', max_length=64)
embedding = embedder(["First batch item text...", "Second batch item text..."]) # [2, 64, 768]

loss = model(x, embedding=embedding)
# ...
sampled = model.sample(
    noise,
    embedding=embedding,
    embedding_scale=5.0, # Classifier-free guidance scale
    num_steps=5
)

Number Conditional Generation

from audio_diffusion_pytorch import NumberEmbedder

embedder = NumberEmbedder(features=768)
embedding = embedder([0.1, 0.2]) # [2, 768]

Usage with Components

UNet1d

from audio_diffusion_pytorch import UNet1d

# UNet used to denoise our 1D (audio) data
unet = UNet1d(
    in_channels=1,
    channels=128,
    patch_size=16,
    multipliers=[1, 2, 4, 4, 4, 4, 4],
    factors=[4, 4, 4, 2, 2, 2],
    attentions=[0, 0, 0, 1, 1, 1, 1],
    num_blocks=[2, 2, 2, 2, 2, 2],
    attention_heads=8,
    attention_features=64,
    attention_multiplier=2,
    resnet_groups=8,
    kernel_multiplier_downsample=2,
    use_nearest_upsample=False,
    use_skip_scale=True,
    use_context_time=True
)

x = torch.randn(3, 1, 2 ** 16)
t = torch.tensor([0.2, 0.8, 0.3])

y = unet(x, t) # [3, 1, 65536], compute 3 samples of ~3 seconds at 22050Hz with the given noise levels t

Diffusion

Training

from audio_diffusion_pytorch import KDiffusion, LogNormalDistribution
from audio_diffusion_pytorch import VDiffusion, UniformDistribution

# Either use KDiffusion
diffusion = KDiffusion(
    net=unet,
    sigma_distribution=LogNormalDistribution(mean = -3.0, std = 1.0),
    sigma_data=0.1,
    dynamic_threshold=0.0
)

# Or use VDiffusion
diffusion = VDiffusion(
    net=unet,
    sigma_distribution=UniformDistribution()
)

x = torch.randn(3, 1, 2 ** 18) # Batch of training audio samples
loss = diffusion(x)
loss.backward() # Do this many times

Sampling

from audio_diffusion_pytorch import DiffusionSampler, KarrasSchedule

sampler = DiffusionSampler(
    diffusion,
    num_steps=5, # Suggested range 2-100, higher better quality but takes longer
    sampler=ADPM2Sampler(rho=1),
    sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0)
)
# Generate a sample starting from the provided noise
y = sampler(noise = torch.randn(1,1,2 ** 18))

Inpainting

from audio_diffusion_pytorch import DiffusionInpainter, KarrasSchedule, ADPM2Sampler

inpainter = DiffusionInpainter(
    diffusion,
    num_steps=5, # Suggested range 2-100, higher for better quality
    num_resamples=1, # Suggested range 1-10, higher for better quality
    sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),
    sampler=ADPM2Sampler(rho=1.0),
)

inpaint = torch.randn(1,1,2 ** 18) # Start track, e.g. one sampled with DiffusionSampler
inpaint_mask = torch.randint(0,2, (1,1,2 ** 18), dtype=torch.bool) # Set to `True` the parts you want to keep
y = inpainter(inpaint = inpaint, inpaint_mask = inpaint_mask)

Infinite Generation

from audio_diffusion_pytorch import SpanBySpanComposer

composer = SpanBySpanComposer(
    inpainter,
    num_spans=4 # Number of spans to inpaint after provided input
)
y_long = composer(y, keep_start=True) # [1, 1, 98304]

Experiments

Report Snapshot Description
Alpha 6bd9279f19 Initial tests on LJSpeech dataset with new architecture and basic DDPM diffusion model.
Bravo a05f30aa94 Elucidated diffusion, improved architecture with patching, longer duration, initial good (unsupervised) results on LJSpeech.
Charlie 50ecc30d70 Train on music with YoutubeDataset, larger patch tests for longer tracks, inpainting tests, initial test with infinite generation using SpanBySpanComposer.
Delta 672876bf13 Test model with the faster ADPM2 sampler and dynamic thresholding.
Echo (current) Test AudioDiffusionUpsampler.

TODO

  • Add elucidated diffusion.
  • Add ancestral DPM2 sampler.
  • Add dynamic thresholding.
  • Add (variational) autoencoder option to compress audio before diffusion (removed).
  • Fix inpainting and make it work with ADPM2 sampler.
  • Add trainer with experiments.
  • Add diffusion upsampler.
  • Add ancestral euler sampler AEulerSampler.
  • Add diffusion autoencoder.
  • Add diffusion upsampler.
  • Add autoencoder bottleneck option for quantization.
  • Add option to provide context tokens (cross attention).
  • Add conditional model with classifier-free guidance.
  • Add option to provide context features mapping.
  • Add option to change number of (cross) attention blocks.
  • Add VDiffusion option.
  • Add flash attention.

Appreciation

Citations

DDPM

@misc{2006.11239,
Author = {Jonathan Ho and Ajay Jain and Pieter Abbeel},
Title = {Denoising Diffusion Probabilistic Models},
Year = {2020},
Eprint = {arXiv:2006.11239},
}

Diffusion inpainting

@misc{2201.09865,
Author = {Andreas Lugmayr and Martin Danelljan and Andres Romero and Fisher Yu and Radu Timofte and Luc Van Gool},
Title = {RePaint: Inpainting using Denoising Diffusion Probabilistic Models},
Year = {2022},
Eprint = {arXiv:2201.09865},
}

Diffusion weighted loss

@misc{2204.00227,
Author = {Jooyoung Choi and Jungbeom Lee and Chaehun Shin and Sungwon Kim and Hyunwoo Kim and Sungroh Yoon},
Title = {Perception Prioritized Training of Diffusion Models},
Year = {2022},
Eprint = {arXiv:2204.00227},
}

Improved UNet architecture

@misc{2205.11487,
Author = {Chitwan Saharia and William Chan and Saurabh Saxena and Lala Li and Jay Whang and Emily Denton and Seyed Kamyar Seyed Ghasemipour and Burcu Karagol Ayan and S. Sara Mahdavi and Rapha Gontijo Lopes and Tim Salimans and Jonathan Ho and David J Fleet and Mohammad Norouzi},
Title = {Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding},
Year = {2022},
Eprint = {arXiv:2205.11487},
}

Elucidated diffusion

@misc{2206.00364,
Author = {Tero Karras and Miika Aittala and Timo Aila and Samuli Laine},
Title = {Elucidating the Design Space of Diffusion-Based Generative Models},
Year = {2022},
Eprint = {arXiv:2206.00364},
}

About

Audio generation using diffusion models, in PyTorch.

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages

  • Python 100.0%