# Diffusion Auto Encoder

This is an implementation based on [audio-diffusion-pytorch](https://github.com/archinetai/audio-diffusion-pytorch).

In [40]:
from pathlib import Path

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

from audio_diffusion_pytorch import DiffusionAE, UNetV0, VDiffusion, VSampler
from audio_encoders_pytorch import MelE1d, TanhBottleneck

from src.datasets import MusicCapsDataset
from src.features import PreProcessor
from src.features.extractor import WaveformExtractor
from src.utils.data import TorchDataset
from src.utils.training import ModelCheckpoint
from src.utils.gpu import create_device
from src.utils.audio import Audio

## Data Preparation

### Generate Dataset

Each MusicCaps clip is 10 seconds long. To adapt the data to this model, we apply the following changes:

* Each 10-second clip is split into two overlapping 5.5-second segments: one from 0 to 5.5 s and one from 4.5 to 10 s.
* For each new segment, we add a note to the 'aspect list' and 'caption' fields indicating which part of the track it is, e.g. '1 of 2' or '2 of 2'.

In [2]:
# Build the MusicCaps dataset as MP3 audio, cutting each 10 s clip into 5.5 s segments.
musiccaps_generator = MusicCapsDataset(
    format="mp3",
    crop_length=5.5,  # segment length in seconds
)
dataset = musiccaps_generator.generate(num_proc=1)

### Preprocessing the data

Following the format defined in the article, each waveform is cropped to $2^{18} = 262{,}144$ samples (about 5.46 s at 48 kHz, i.e. roughly 5.5 seconds) so that it matches the network's input size.

In [3]:
# Crop every waveform to 2**18 samples so it matches the network input size.
# Tip: We don't need to save waveforms, so the split sets are not persisted.
def _make_waveform_extractor(dataset):
    return WaveformExtractor(dataset, column="audio", crop_length=2**18)


preprocessor = PreProcessor(dataset, _make_waveform_extractor)
train, test = preprocessor.get_train_test_split(
    path=musiccaps_generator.get_processed_folder(),
    save_split_sets=False,
)

Loading train/test indexes...


Generating train subset [Waveform]: 100%|██████████| 8712/8712 [06:23<00:00, 22.71it/s]
Generating test subset [Waveform]: 100%|██████████| 2178/2178 [01:54<00:00, 19.05it/s]


In [4]:
# Restrict training to a small subset for quick experiments.
# Set TRAIN_SUBSET_SIZE = None to train on the full training split.
TRAIN_SUBSET_SIZE = 5
if TRAIN_SUBSET_SIZE is not None:
    train = train[:TRAIN_SUBSET_SIZE]

## Training

### Parameters

In [44]:
# Training / generation settings.
NUM_EPOCHS = 5
BATCH_SIZE = 8
DECODE_STEPS = 100  # sampler steps used when decoding latents back to audio
MODEL_NAME = "DiffusionAE"  # also used as the sub-folder under models/
MODEL_RESULT = "last-epoch.ckpt"  # file name of the final saved state_dict
SAMPLING_RATE = 48000  # Hz

# Seed torch's RNGs (all devices) so weight initialisation and DataLoader shuffling,
# which run after this cell, are reproducible across runs.
SEED = 42
torch.manual_seed(SEED)

### Adapting train data

We need to adapt the data to the standard PyTorch format:

* The input shape suggested in the documentation is [batch, in_channels, length], while our dataset is in the format [batch, length], where length = sampling rate × duration. We therefore add a channel dimension.
* We use a `DataLoader` to batch and shuffle the data efficiently.
* We also use the GPU, if one is available.

In [6]:
device = create_device()


def transform(waveform):
    """Add a leading channel axis to a single waveform: (length,) -> (1, length).

    NOTE(review): assumes TorchDataset applies ``transform`` per item, so the
    DataLoader then collates items into (batch, 1, length) — confirm against
    src.utils.data.TorchDataset.
    """
    return waveform.unsqueeze(0)


train_dataset = TorchDataset(train, transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

### Model definition

We use the same configuration as the original example, and add an optimizer that updates the model's parameters using the gradients computed during backpropagation.
The loss is computed internally by the model's forward pass.

In [7]:
# Diffusion autoencoder: MelE1d encodes the waveform into a latent (autoencoder.encode),
# and the UNetV0 diffusion network decodes audio from it (autoencoder.decode).
autoencoder = DiffusionAE(
    encoder=MelE1d(
        in_channels=1,  # single-channel waveforms, matching the transform above
        channels=512,
        multipliers=[1, 1],
        factors=[2],
        num_blocks=[12],
        out_channels=32,
        mel_channels=80,
        # Use the shared constant so the mel front-end and the saved audio agree on the rate.
        mel_sample_rate=SAMPLING_RATE,
        mel_normalize_log=True,
        bottleneck=TanhBottleneck(),
    ),
    inject_depth=6,
    net_t=UNetV0,
    in_channels=1,
    # Presumably one entry per U-Net level (values taken from the library's example).
    channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024],
    factors=[1, 4, 4, 4, 2, 2, 2, 2, 2],
    items=[1, 2, 2, 2, 2, 2, 2, 4, 4],
    diffusion_t=VDiffusion,
    sampler_t=VSampler,
)

autoencoder = autoencoder.to(device)

LEARNING_RATE = 1e-3
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LEARNING_RATE)

### Model Checkpoint

In [8]:
# Checkpoint helper keyed by the model name; its resume()/save() are used by the training loop.
checkpoint_manager = ModelCheckpoint(
    MODEL_NAME,
)

### Load Model from Checkpoint

In [9]:
# Optional warm start: load a previously saved state_dict before training.
# Leave as None to skip.
# NOTE(review): checkpoint_manager.resume() in the next cell may overwrite these
# weights if it finds a checkpoint — confirm the intended precedence.
# NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
PRETRAINED_WEIGHTS_PATH = None
if PRETRAINED_WEIGHTS_PATH is not None:
    autoencoder.load_state_dict(torch.load(PRETRAINED_WEIGHTS_PATH, map_location=device))

### Training loop

In [10]:
# Train the autoencoder, resuming from the latest checkpoint if one exists.
# A checkpoint is written every CHECKPOINT_EVERY_N_BATCHES batches and after the
# last batch of each epoch (so short epochs are still checkpointed).
CHECKPOINT_EVERY_N_BATCHES = 10

start_epoch, start_batch_index, _ = checkpoint_manager.resume(autoencoder, optimizer)
num_batches = len(train_dataloader)

# The generation cell further down calls autoencoder.eval(); make sure a re-run of
# this cell trains in training mode.
autoencoder.train()

for epoch in tqdm(range(start_epoch, NUM_EPOCHS), desc="Epochs"):
    with tqdm(train_dataloader, unit="i", leave=False, desc="Batches") as tepoch:
        for i, batch in enumerate(tepoch):
            # Skip batches already processed before the checkpoint was taken.
            # NOTE(review): with shuffle=True the skipped batches are not the same samples
            # as in the interrupted run; also confirm whether resume() returns the saved
            # index or the next one (save() below stores the index of the batch just trained).
            if epoch == start_epoch and i < start_batch_index:
                continue

            batch = batch.to(device)
            optimizer.zero_grad()
            loss = autoencoder(batch)  # the model computes its loss internally
            loss.backward()
            optimizer.step()

            loss_value = loss.item()  # read the scalar once per batch
            description = f"Epoch {epoch} Batch {i+1}/{num_batches} Loss: {loss_value:.4f}"

            is_last_batch = i + 1 == num_batches
            if (i + 1) % CHECKPOINT_EVERY_N_BATCHES == 0 or is_last_batch:
                checkpoint_manager.save(autoencoder, optimizer, epoch, i, loss_value)
                description += " Saved"
            tepoch.set_description(description)
        

Nenhum checkpoint encontrado.


Epochs: 100%|██████████| 5/5 [07:31<00:00, 90.28s/it] 


In [11]:
# Save final model
# Save the final model weights (state_dict only; optimizer state is not included).
final_model_path = Path("models") / MODEL_NAME / MODEL_RESULT
# Create the folder if needed so the save does not fail when it does not exist yet.
final_model_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(autoencoder.state_dict(), final_model_path)

In [30]:
# Reconstruct only a small slice of the test split; increase for a fuller evaluation.
N_TEST_SAMPLES = 1
test_dataset = TorchDataset(test[:N_TEST_SAMPLES], transform=transform)
# No shuffling at evaluation time, so the generated audio keeps the test-set order.
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [37]:

# Reconstruct the test audio: encode each batch to a latent, then decode it back to a
# waveform with DECODE_STEPS sampler steps.
autoencoder.eval()
generated_chunks = []

with torch.no_grad():
    for batch in test_dataloader:
        latent = autoencoder.encode(batch.to(device))
        generated_audio = autoencoder.decode(latent, num_steps=DECODE_STEPS)
        # Move each result to the CPU so GPU memory does not grow with the test-set size.
        generated_chunks.append(generated_audio.cpu())

# Concatenate once at the end; torch.cat inside the loop re-copies everything each time.
generated_audios = torch.cat(generated_chunks, dim=0) if generated_chunks else torch.empty(0)


In [55]:
# Hand the generated waveforms (as a NumPy array) to the project's Audio helper,
# writing into the model's folder at the configured sampling rate.
output_folder = f"models/{MODEL_NAME}"
generated_waveforms = generated_audios.cpu().numpy()
Audio.save(generated_waveforms, sample_rate=SAMPLING_RATE, folder_path=output_folder)