In [1]:
!nvidia-smi

# If this command fails or lists no devices, no NVIDIA GPU (or driver) is available or detected

Mon Jul 24 15:23:24 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.182.03   Driver Version: 470.182.03   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100 80G...  On   | 00000001:00:00.0 Off |                    0 |
| N/A   32C    P0    42W / 300W |      0MiB / 80994MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [7]:
# One-time environment setup (kept commented out so re-running the notebook does not reinstall).
# Uncomment and run once, then restart the kernel: pip's output notes that a restart
# may be needed to use the updated packages.
# NOTE(review): scikit-learn is pinned to 0.24.0, presumably for compatibility with the
# pretrained HuBERT k-means quantizer file — TODO confirm.
#%pip install audiolm-pytorch
#%pip install pydub
#%pip install scikit-learn==0.24.0

Collecting scikit-learn==0.24.0
  Downloading scikit_learn-0.24.0-cp38-cp38-manylinux2010_x86_64.whl (24.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.9/24.9 MB[0m [31m42.3 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Installing collected packages: scikit-learn
  Attempting uninstall: scikit-learn
    Found existing installation: scikit-learn 1.3.0
    Uninstalling scikit-learn-1.3.0:
      Successfully uninstalled scikit-learn-1.3.0
Successfully installed scikit-learn-0.24.0
Note: you may need to restart the kernel to use updated packages.


## Setup

### Imports & paths

In [1]:
# imports
import math
import wave
import struct
import os
import urllib.request
import tarfile
from audiolm_pytorch import SoundStream, SoundStreamTrainer, HubertWithKmeans, SemanticTransformer, SemanticTransformerTrainer, HubertWithKmeans, CoarseTransformer, CoarseTransformerWrapper, CoarseTransformerTrainer, FineTransformer, FineTransformerWrapper, FineTransformerTrainer, AudioLM
from torch import nn
import torch
import torchaudio


# define all dataset paths, checkpoints, etc
# Folder of .wav files handed to the trainer as its `folder` argument.
dataset_folder = "audio/wav"
# SoundStream checkpoint path; the number in the file name is the training step,
# so this can change depending on number of steps
soundstream_ckpt = "results/soundstream.8.pt"
# HuBERT model checkpoint.
hubert_ckpt = 'hubert/hubert_base_ls960.pt'
# HuBERT quantizer file: listed in row "HuBERT Base (~95M params)", column Quantizer.
# Plain string literal: there is nothing to interpolate, so no f-prefix is needed.
hubert_quantizer = 'hubert/hubert_base_ls960_L9_km500.bin'

### Data

In [3]:
# Disabled one-time preprocessing step: converts every file listed in audio/segments
# from MP3 to WAV with pydub and writes the result to audio/wav/<name>.wav.
# Left commented out; uncomment to regenerate the WAV files.
#from pydub import AudioSegment
#
#files = os.listdir("audio/segments")
#for f in files:
#    f = f.replace(".mp3","")
#    output_file = f"audio/wav/{f}.wav"
#    input_file = f"audio/segments/{f}.mp3"
#    sound = AudioSegment.from_mp3(input_file)
#    sound.export(output_file, format="wav")

## Training

Now that we have a dataset, we can train AudioLM.

**Note**: do NOT type "y" to overwrite previous experiments/checkpoints when running through the cells here unless you're ready to lose the entire results folder! Otherwise you will end up erasing things (e.g. if you train SoundStream first and then choose "overwrite" when training SemanticTransformer, you lose the SoundStream checkpoint).

### SoundStream

In [4]:
# Build the SoundStream model, wrap it in a trainer that reads audio from
# `dataset_folder`, move the trainer to the GPU and run a short training loop.
soundstream = SoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 8,
)

trainer = SoundStreamTrainer(
    soundstream,
    folder = dataset_folder,
    batch_size = 4,
    grad_accum_every = 32,        # effective batch size of 4 * 32 = 128 (not 32), assuming gradients accumulate over this many batches
    data_max_length_seconds = 4,  # train on 4 second audio
    save_results_every = 2,
    save_model_every = 4,
    num_train_steps = 9
).cuda()
# NOTE: I changed num_train_steps to 9 (aka 8 + 1) from 10000 to make things go faster for demo purposes
# adjusting save_*_every variables for the same reason
# NOTE(review): 9 steps presumably means steps 0..8 run, so the last model save
# (every 4 steps) lands on step 8 — verify against the checkpoint files in results/.

trainer.train()

training with dataset of 7827 samples and validating with randomly splitted 412 samples
0: soundstream total loss: 30.804, soundstream recon loss: 0.026 | discr (scale 1) loss: 1.999 | discr (scale 0.5) loss: 2.000 | discr (scale 0.25) loss: 2.000
0: saving to results
0: saving model to results
1: soundstream total loss: 25.228, soundstream recon loss: 0.012 | discr (scale 1) loss: 1.995 | discr (scale 0.5) loss: 1.998 | discr (scale 0.25) loss: 1.997
2: soundstream total loss: 27.568, soundstream recon loss: 0.004 | discr (scale 1) loss: 2.000 | discr (scale 0.5) loss: 2.000 | discr (scale 0.25) loss: 2.000
2: saving to results
3: soundstream total loss: 22.150, soundstream recon loss: 0.002 | discr (scale 1) loss: 2.009 | discr (scale 0.5) loss: 2.004 | discr (scale 0.25) loss: 2.007
4: soundstream total loss: 18.412, soundstream recon loss: 0.003 | discr (scale 1) loss: 2.012 | discr (scale 0.5) loss: 2.005 | discr (scale 0.25) loss: 2.009
4: saving to results
4: saving model to res