# Diffusion Autoencoder

This notebook uses the diffusion autoencoder implementation from [audio-diffusion-pytorch](https://github.com/archinetai/audio-diffusion-pytorch).

In [1]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

from audio_diffusion_pytorch import DiffusionAE, UNetV0, VDiffusion, VSampler
from audio_encoders_pytorch import MelE1d, TanhBottleneck

from src.datasets import MusicCapsDataset
from src.features import PreProcessor
from src.features.extractor import WaveformExtractor
from src.utils.data import TorchDataset

## Data Preparation

### Generate Dataset

Each clip in the MusicCaps dataset is 10 seconds long. To fit the model's input size, we apply the following changes to the data:

* We split each clip into two overlapping 5.5-second segments: one from 0 to 5.5 seconds and one from 4.5 to 10 seconds.
* For each new segment, we add a note to its 'aspect list' and 'caption' identifying which part of the track it is, for example '1 of 2' or '2 of 2'.
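
The overlapping split described above can be sketched as follows. This is a minimal illustration, not the code used by `MusicCapsDataset`; the helper `split_into_windows` and its arguments are hypothetical names introduced here. It assumes the first window starts at the beginning of the clip and the second ends at its end.

```python
def split_into_windows(total_seconds, crop_seconds):
    """Return (start, end) pairs for two overlapping windows.

    Hypothetical helper: the first window is anchored at the start of the
    clip and the second at the end, so they overlap in the middle.
    """
    if not 0 < crop_seconds <= total_seconds:
        raise ValueError("crop_seconds must be in (0, total_seconds]")
    first = (0.0, float(crop_seconds))
    second = (float(total_seconds - crop_seconds), float(total_seconds))
    return [first, second]


windows = split_into_windows(total_seconds=10, crop_seconds=5.5)
for index, (start, end) in enumerate(windows, start=1):
    # The label mirrors the '1 of 2' / '2 of 2' note added to the caption.
    print(f"{index} of {len(windows)}: {start} s to {end} s")
```

With a 10-second clip and a 5.5-second crop, the two windows share the second between 4.5 and 5.5 seconds.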

In [None]:
musiccaps_generator = MusicCapsDataset(format="mp3", crop_length=5.5)
dataset = musiccaps_generator.generate(num_proc=1)

### Preprocessing the data

To match the format defined by the article, we crop each waveform to 2**18 samples (approximately 5.5 seconds, assuming 48 kHz audio) so that it fits the network input.
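
As a quick check of the numbers above, the duration of 2**18 samples can be computed directly. The 48 kHz sample rate is an assumption taken from the `mel_sample_rate=48000` setting used in the model definition in this notebook; with a different sample rate the duration changes accordingly.

```python
num_samples = 2**18   # crop length used for the network input
sample_rate = 48_000  # assumed sample rate in Hz

duration_seconds = num_samples / sample_rate
print(f"{num_samples} samples at {sample_rate} Hz = {duration_seconds:.2f} s")
```

The result is slightly under 5.5 seconds, which is why the 5.5-second segments can be cropped to this length without padding.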

In [None]:
# Tip: We don't need to save waveforms.
extractor = lambda dataset: WaveformExtractor(dataset, column="audio", crop_length=2**18)
train, test = PreProcessor(dataset, extractor).get_train_test_split(
    path=musiccaps_generator.get_processed_folder(), save_split_sets=False
)

## Training

### Adapting train data

We need to adapt the data to the standard PyTorch format:

* The input shape suggested in the documentation is [batch, in_channels, length]. Our dataset is in the format [batch, length], so we add a channel dimension of size 1. (length is the number of samples, i.e. the sample rate multiplied by the duration.)
* We use DataLoader, PyTorch's utility for batching and shuffling, to access our data.
* We also move the data to the GPU, if one is available.
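
The shape handling in the list above can be illustrated with a small, self-contained sketch. `ToyWaveformDataset` is a hypothetical stand-in for the project's `TorchDataset` (whose internals are not shown here), and the waveforms are random tensors of 16 samples rather than real audio. It assumes the transform is applied to each item before batching.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyWaveformDataset(Dataset):
    """Hypothetical stand-in for the notebook's dataset wrapper."""

    def __init__(self, waveforms, transform=None):
        self.waveforms = waveforms
        self.transform = transform

    def __len__(self):
        return len(self.waveforms)

    def __getitem__(self, index):
        waveform = self.waveforms[index]
        if self.transform is not None:
            waveform = self.transform(waveform)
        return waveform


# 6 fake mono waveforms of 16 samples each (real ones would have 2**18 samples).
waveforms = torch.randn(6, 16)

# Each item has shape (length,); unsqueeze(0) turns it into (1, length).
toy_dataset = ToyWaveformDataset(waveforms, transform=lambda x: x.unsqueeze(0))
toy_loader = DataLoader(toy_dataset, batch_size=3, shuffle=False)

toy_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
first_batch = next(iter(toy_loader)).to(toy_device)
print(first_batch.shape)  # torch.Size([3, 1, 16])
```

The DataLoader stacks the per-item tensors, so the batch arrives in the [batch, in_channels, length] layout with a single channel.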

In [None]:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 8
transform = lambda x: x.unsqueeze(0)  # per sample: (length,) -> (1, length); batches become (batch, 1, length)

train_dataset = TorchDataset(train, transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

### Model definition

We use the same model configuration as the original example from audio-diffusion-pytorch. We also add an Adam optimizer, which updates the model parameters from the gradients computed during backpropagation.
The loss is computed internally: calling the model on a batch returns the loss.

In [None]:
autoencoder = DiffusionAE(
    encoder=MelE1d(
        in_channels=1,
        channels=512,
        multipliers=[1, 1],
        factors=[2],
        num_blocks=[12],
        out_channels=32,
        mel_channels=80,
        mel_sample_rate=48000,
        mel_normalize_log=True,
        bottleneck=TanhBottleneck(),
    ),
    inject_depth=6,
    net_t=UNetV0,
    in_channels=1,
    channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024],
    factors=[1, 4, 4, 4, 2, 2, 2, 2, 2],
    items=[1, 2, 2, 2, 2, 2, 2, 4, 4],
    diffusion_t=VDiffusion,
    sampler_t=VSampler,
)

autoencoder = autoencoder.to(device)

optimizer = torch.optim.Adam(autoencoder.parameters(), lr=0.001)
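
A quick sanity check on the U-Net configuration above: the three per-layer lists should have the same length. In addition, assuming each entry of `factors` is the downsampling factor applied at that layer (our reading of the configuration, not something verified against the library here), the cropped input length should be divisible by their product. The sketch below only does this arithmetic and does not touch the model.

```python
import math

# Values copied from the model definition above.
channels = [8, 32, 64, 128, 256, 512, 512, 1024, 1024]
factors = [1, 4, 4, 4, 2, 2, 2, 2, 2]
items = [1, 2, 2, 2, 2, 2, 2, 4, 4]
input_length = 2**18

# Every layer needs one entry in each list.
assert len(channels) == len(factors) == len(items)

# Assumption: the overall downsampling is the product of the per-layer factors.
total_factor = math.prod(factors)
remainder = input_length % total_factor
innermost_length = input_length // total_factor

print(f"total downsampling factor: {total_factor}")
print(f"remainder: {remainder}, length at the innermost layer: {innermost_length}")
```

Under that assumption, a crop length of 2**18 divides evenly, which is one reason to prefer a power-of-two input length.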

### Training loop

In [None]:
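num_epochs = 10
autoencoder.train()  # make sure the model is in training mode
for epoch in tqdm(range(num_epochs)):
    for batch in train_dataloader:
        batch = batch.to(device)    # move the batch to the same device as the model
        optimizer.zero_grad()       # clear gradients from the previous step
        loss = autoencoder(batch)   # forward pass; the model returns the loss
        loss.backward()             # backpropagate
        optimizer.step()            # update the parameters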