In [1]:
!nvidia-smi

# imports
import math
import wave
import struct
import os
import urllib.request
import tarfile
import random
from audiolm_pytorch import SoundStream, SoundStreamTrainer, HubertWithKmeans, SemanticTransformer, SemanticTransformerTrainer, HubertWithKmeans, CoarseTransformer, CoarseTransformerWrapper, CoarseTransformerTrainer, FineTransformer, FineTransformerWrapper, FineTransformerTrainer, AudioLM
from torch import nn
import torch
import torchaudio
from random import shuffle
from fs.osfs import OSFS
from fs.mountfs import MountFS
from fs import open_fs

sample_rate = 16000

# define all dataset paths, checkpoints, etc
dataset_folder = "placeholder_dataset"
soundstream_ckpt = "results/soundstream.8.pt" # this can change depending on number of steps
hubert_ckpt = 'hubert/hubert_base_ls960.pt'
hubert_quantizer = f'hubert/hubert_base_ls960_L9_km500.bin' # listed in row "HuBERT Base (~95M params)", column Quantizer

Sat Jul 29 22:49:10 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.116.04   Driver Version: 525.116.04   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA RTX A5000    Off  | 00000000:00:05.0 Off |                  Off |
| 30%   31C    P8    15W / 230W |      1MiB / 24564MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

2023-07-29 22:49:12 | INFO | fairseq.tasks.text_to_speech | Please install tensorboardX: pip install tensorboardX


In [2]:
import math
from IPython.display import Audio
import matplotlib.pyplot as plt

def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None):
    waveform = waveform.numpy()

    num_channels, num_frames = waveform.shape
    time_axis = torch.arange(0, num_frames) / sample_rate

    figure, axes = plt.subplots(num_channels, 1)
    if num_channels == 1:
        axes = [axes]
    for c in range(num_channels):
        axes[c].plot(time_axis, waveform[c], linewidth=1)
        axes[c].grid(True)
        if num_channels > 1:
            axes[c].set_ylabel(f"Channel {c+1}")
        if xlim:
            axes[c].set_xlim(xlim)
    figure.suptitle(title)
    plt.show(block=False)
    
def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None):
    waveform = waveform.numpy()

    num_channels, _ = waveform.shape

    figure, axes = plt.subplots(num_channels, 1)
    if num_channels == 1:
        axes = [axes]
    for c in range(num_channels):
        axes[c].specgram(waveform[c], Fs=sample_rate)
        if num_channels > 1:
            axes[c].set_ylabel(f"Channel {c+1}")
        if xlim:
            axes[c].set_xlim(xlim)
    figure.suptitle(title)
    plt.show(block=False)

In [3]:
# Setup filesystem
dataset_fs = MountFS()

#dataset_fs.mount('dns4-read', open_fs('s3://music-clip-dataset/dns/clean-fullband/read_speech/'))
dataset_fs.mount('dns4-read', open_fs('s3://music-clip-dataset/dns/clean-fullband/russian/M-AILABS_Speech_Dataset/ru_RU_47hrs_48k/female/hajdurova/chetvero_nischih/wavs/'))
[dataset_fs.mount(f'Jamendo-{_:02x}', open_fs(f's3://music-clip-dataset/clips/train/{_:02x}/')) for _ in range(4)]
[dataset_fs.mount(f'Psychostream-{_:02x}', open_fs(f's3://music-clip-dataset/psychostream/train/{_:02x}/')) for _ in range(4)]
#dataset_fs.add_fs('dns4-read', open_fs('s3://music-clip-dataset/dns/clean-fullband/read_speech/'))

# Get all music files and shuffle them
music_files = []
for path in dataset_fs.walk.files(filter=['*.mp3']):
    music_files.append(path)
    #print(path)
random.shuffle(music_files)
print(music_files[:50])

['/Psychostream-02/e0d31d09f76c43529b772c8ce6b80a02.train.6.mp3', '/Psychostream-03/83db3b5095059d78d798d82491b45d03.train.4.mp3', '/Jamendo-00/bfc4a787b6601b1f300d1357642d1200.train.6.mp3', '/Jamendo-02/fff6addacb70493044227f2d0bad2702.train.7.mp3', '/Psychostream-00/8d7816bb4379e5583ca3fa8012ce1b00.train.4.mp3', '/Jamendo-00/8a5a0bf285df6a6d2991580b8b827c00.train.4.mp3', '/Jamendo-03/37f0f7c8a59df5e1a177d4f944252403.train.10.mp3', '/Psychostream-03/061bff76a9335d453676abbc2b16a203.train.4.mp3', '/Jamendo-03/347de7fa20c4e5882ebb2e9df0f08403.train.11.mp3', '/dns4-read/chetvero_nischih_s000806.mp3', '/Psychostream-00/24d453f51dbf6374ebe4ba0ce8490900.train.6.mp3', '/Psychostream-02/92f71d5ad0abf2664af3add26f410302.train.3.mp3', '/Psychostream-03/c1e5439009422152ec89bd70b810d303.train.6.mp3', '/Psychostream-00/6e659208e04142e224a321caf6424700.train.12.mp3', '/Psychostream-03/fa84620c49b5d423270f10c9e8bb1703.train.0.mp3', '/Psychostream-02/2f57ca5f405f5d486bb0ba8061f5f702.train.6.mp3', '/P

In [4]:
from email.mime import audio
from io import BytesIO
import os
import random
import itertools
from einops import rearrange
import fs.info
from torchaudio.functional import resample
from torch import Tensor
import torch.nn.functional as F
from audiolm_pytorch.soundstream import cast_tuple
from audiolm_pytorch.utils import curtail_to_multiple

class SeabassIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, start, end, files: list[str], fs: MountFS, effects, clip_len:float = 2.0, tick_char='.', sample_rate=16000):
        super(SeabassIterableDataset).__init__()
        self.start = start
        self.end = end
        self.effects = effects
        self.files = files
        self.fs = fs
        self.clip_len = clip_len
        self.tick_char = tick_char
        self.sample_rate = sample_rate

    def getSample(self, index):
        global sample_rate
        
        filename = self.files[index]
        #data = BytesIO(self.fs.readbytes(filename))
        with self.fs.open(filename, 'rb') as f:
            reader = torchaudio.io.StreamReader(src = f)
            reader.add_basic_audio_stream(
                frames_per_chunk=self.sample_rate//10,
                stream_index=0,
                sample_rate=self.sample_rate,
            )

            audio_tensor = None
            for i, waveform in enumerate(reader.stream()):
                if audio_tensor is None:
                    audio_tensor = waveform[0]
                else:
                    audio_tensor = torch.cat((audio_tensor, waveform[0]), 0)
                
                streamed_len = (audio_tensor.shape[0] / self.sample_rate) 
                if streamed_len >= self.clip_len:
                    break # we have enough data
               
            audio_tensor = audio_tensor.reshape(1, -1) 
            if audio_tensor.shape[0] > 1:
                # the audio has more than 1 channel, convert to mono
                audio_tensor = torch.mean(audio_tensor, dim=0).unsqueeze(0)

            #num_outputs = 1
            #audio_tensor = cast_tuple(audio_tensor, num_outputs)
            #data_tuple = tuple((resample(d, sample_hz, target_sample_hz) if exists(target_sample_hz) else d) for d, target_sample_hz in zip(data, self.target_sample_hz))

            output = []
            audio_length = audio_tensor.size(1)

            ## pad or curtail
            #max_length = 99999999999999
            #if audio_length > max_length:
            #    max_start = audio_length - max_length
            #    start = torch.randint(0, max_start, (1, ))
            #    audio_tensor = audio_tensor[:, start:start + max_length]
            #else:
            #    audio_tensor = F.pad(audio_tensor, (0, max_length - audio_length), 'constant')

            audio_tensor = rearrange(audio_tensor, '1 ... -> ...')
            output = audio_tensor.float()

            print(self.tick_char, end='', flush=True)
            return output

    
    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:  # single-process data loading, return the full iterator
            iter_start = self.start
            iter_end = self.end
        else:  # in a worker process
            # split workload
            per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
            worker_id = worker_info.id
            iter_start = self.start + worker_id * per_worker
            iter_end = min(iter_start + per_worker, self.end)

        for i in range(iter_start, iter_end):
            sample = self.getSample(i)
            yield sample

# Setup dataloader
effects = [
    ["norm", "1"],
]

needed_clip_len = float(320 * 40)/sample_rate

train_subset = music_files[:-1000]
val_subset = music_files[-1000:]

dst = SeabassIterableDataset(start=0, end=len(train_subset), files=train_subset, 
                            fs=dataset_fs, effects=effects, clip_len=needed_clip_len)
dsv = SeabassIterableDataset(start=0, end=len(val_subset), files=val_subset, 
                            fs=dataset_fs, effects=effects, clip_len=2.0, tick_char='*')


In [5]:
# Test the sampler
#dl_train.sampler
#sampler = ds_train.__iter__()
#sample = sampler.__next__()
#print(sample)
#print(waveform1.shape, sr1)
#plot_waveform(waveform1, sr1, title="Train", xlim=(-0.1, 3.2))
#plot_specgram(waveform1, sr1, title="Train", xlim=(0, 3.04))
#Audio(waveform1, rate=sr1)
#
#_, waveform2, sr2 = ds_val.__iter__()
#print(waveform2.shape, sr2)
#plot_waveform(waveform2, sr2, title="Val", xlim=(-0.1, 3.2))
#plot_specgram(waveform2, sr2, title="Val", xlim=(0, 3.04))
#Audio(waveform2, rate=sr2)


### SoundStream

In [None]:
soundstream = SoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 8,
)

trainer = SoundStreamTrainer(
    soundstream,
    #folder = dataset_folder,
    #train_dataloader=dl_train,
    #val_dataloader=dl_val,
    train_dataset=dst,
    val_dataset=dsv,
    #lr = 0.001,
    batch_size = 12,
    grad_accum_every = 8,         # effective batch size of 32
    data_max_length = 320 * 32,
    save_results_every = 20,
    save_model_every = 100,
    num_train_steps = 2000,
).cuda()
# NOTE: I changed num_train_steps to 9 (aka 8 + 1) from 10000 to make things go faster for demo purposes
# adjusting save_*_every variables for the same reason

trainer.train()

............................................................................................................................................................................................................0: soundstream total loss: 59.463, soundstream recon loss: 0.044 | discr (scale 1) loss: 2.000 | discr (scale 0.5) loss: 2.000 | discr (scale 0.25) loss: 2.000
************0: saving to results
0: saving model to results
................................................................................................................................................................................................1: soundstream total loss: 62.514, soundstream recon loss: 0.045 | discr (scale 1) loss: 1.997 | discr (scale 0.5) loss: 1.997 | discr (scale 0.25) loss: 1.998
................................................................................................................................................................................................2: soundstream total loss: 59.66

## Inference

In [None]:
# Everything together


In [None]:
output_path = "out.wav"
sample_rate = 44100
torchaudio.save(output_path, generated_wav.cpu(), sample_rate)