In [1]:
!nvidia-smi

# imports
import math
import wave
import struct
import os
import urllib.request
import tarfile
import random
from audiolm_pytorch import SoundStream, SoundStreamTrainer, HubertWithKmeans, SemanticTransformer, SemanticTransformerTrainer, HubertWithKmeans, CoarseTransformer, CoarseTransformerWrapper, CoarseTransformerTrainer, FineTransformer, FineTransformerWrapper, FineTransformerTrainer, AudioLM
from torch import nn
import torch
import torchaudio
import ray
from random import shuffle
from fs.osfs import OSFS
from fs.mountfs import MountFS
from fs import open_fs

sample_rate = 16000

# define all dataset paths, checkpoints, etc
dataset_folder = "placeholder_dataset"
soundstream_ckpt = "results/soundstream.8.pt" # this can change depending on number of steps
hubert_ckpt = 'hubert/hubert_base_ls960.pt'
hubert_quantizer = f'hubert/hubert_base_ls960_L9_km500.bin' # listed in row "HuBERT Base (~95M params)", column Quantizer

Mon Jul 31 00:24:02 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 536.67                 Driver Version: 536.67       CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                     TCC/WDDM  | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 4070 Ti   WDDM  | 00000000:2D:00.0  On |                  N/A |
|  0%   37C    P8               7W / 285W |   2607MiB / 12282MiB |      2%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

  from .autonotebook import tqdm as notebook_tqdm
2023-07-31 00:24:04 | INFO | fairseq.tasks.text_to_speech | Please install tensorboardX: pip install tensorboardX
2023-07-31 00:24:05,658	INFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [2]:
import math
from IPython.display import Audio
import matplotlib.pyplot as plt

def plot_waveform(waveform, sample_rate, title="Waveform", xlim=None):
    waveform = waveform.numpy()

    num_channels, num_frames = waveform.shape
    time_axis = torch.arange(0, num_frames) / sample_rate

    figure, axes = plt.subplots(num_channels, 1)
    if num_channels == 1:
        axes = [axes]
    for c in range(num_channels):
        axes[c].plot(time_axis, waveform[c], linewidth=1)
        axes[c].grid(True)
        if num_channels > 1:
            axes[c].set_ylabel(f"Channel {c+1}")
        if xlim:
            axes[c].set_xlim(xlim)
    figure.suptitle(title)
    plt.show(block=False)
    
def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None):
    waveform = waveform.numpy()

    num_channels, _ = waveform.shape

    figure, axes = plt.subplots(num_channels, 1)
    if num_channels == 1:
        axes = [axes]
    for c in range(num_channels):
        axes[c].specgram(waveform[c], Fs=sample_rate)
        if num_channels > 1:
            axes[c].set_ylabel(f"Channel {c+1}")
        if xlim:
            axes[c].set_xlim(xlim)
    figure.suptitle(title)
    plt.show(block=False)

In [3]:
ray.init(num_cpus=16, num_gpus=1)

# Setup filesystem
dataset_fs = MountFS()

#dataset_fs.mount('dns4-read', open_fs('s3://music-clip-dataset/dns/clean-fullband/read_speech/'))
dataset_fs.mount('dns4-read', open_fs('s3://music-clip-dataset/dns/clean-fullband/russian/M-AILABS_Speech_Dataset/ru_RU_47hrs_48k/female/hajdurova/chetvero_nischih/wavs/'))
[dataset_fs.mount(f'Jamendo-{_:02x}', open_fs(f's3://music-clip-dataset/clips/train/{_:02x}/')) for _ in range(4)]
[dataset_fs.mount(f'Psychostream-{_:02x}', open_fs(f's3://music-clip-dataset/psychostream/train/{_:02x}/')) for _ in range(4)]
#dataset_fs.add_fs('dns4-read', open_fs('s3://music-clip-dataset/dns/clean-fullband/read_speech/'))

# Get all music files and shuffle them
music_files = []
for path in dataset_fs.walk.files(filter=['*.mp3']):
    music_files.append(path)
    #print(path)
random.shuffle(music_files)
print(music_files[:50])

2023-07-31 00:24:08,731	INFO worker.py:1621 -- Started a local Ray instance.


['/Psychostream-02/f9aeacced9e40a71a7f01b9731c9d302.train.5.mp3', '/Jamendo-02/136a200cbf88932feed2ef3d8d25e102.train.4.mp3', '/Psychostream-01/153b89d76fe7b3ede26e625423217e01.train.6.mp3', '/Psychostream-00/a3a9a6783941229bd8fb96a996729300.train.3.mp3', '/Psychostream-01/6677c1c8c25251a82cde5f8cd66a6601.train.3.mp3', '/Psychostream-03/c39479df8287488fbf96a6a9d7ae4803.train.5.mp3', '/Psychostream-00/869b6b5a52c765445379c44867e76c00.train.7.mp3', '/Jamendo-00/ec5226131f0010e9352768a4d16d3500.train.1.mp3', '/Psychostream-02/54c2d39073f22bc307255fd4cec35302.train.9.mp3', '/Jamendo-01/59247d8f9bc3cca2c8b002445ea28801.train.4.mp3', '/Psychostream-00/1c7a39fd89dd74f7c5f8128de4104b00.train.9.mp3', '/Psychostream-03/1b3855e83c703767d638da1d6ad87903.train.4.mp3', '/Psychostream-01/194f483d30c0032027ee55d5d1905101.train.3.mp3', '/dns4-read/chetvero_nischih_s000471_seg_0.mp3', '/Jamendo-01/0f6be55f84c3beaa755153b77a58a201.train.4.mp3', '/Psychostream-03/e79b314817ffa883d446628b520c8a03.train.4.m

In [4]:
from email.mime import audio
from io import BytesIO
import os
import random
import itertools
from time import time
from einops import rearrange
import fs.info
from torchaudio.functional import resample
from torch import Tensor
import torch.nn.functional as F
from audiolm_pytorch.soundstream import cast_tuple
from audiolm_pytorch.utils import curtail_to_multiple

@ray.remote
class SeabassBatcherComputer():
    def __init__(self, clip_len:float = 2.0, sample_rate=16000):
        self.clip_len = clip_len
        self.sample_rate = sample_rate

    def getSample(self, filename):
        with self.fs.open(filename, 'rb') as f:
            reader = torchaudio.io.StreamReader(src = f)
            reader.add_basic_audio_stream(
                frames_per_chunk=self.sample_rate//10,
                stream_index=0,
                sample_rate=self.sample_rate,
            )

            audio_tensor = None
            for i, waveform in enumerate(reader.stream()):
                if audio_tensor is None:
                    audio_tensor = waveform[0]
                else:
                    audio_tensor = torch.cat((audio_tensor, waveform[0]), 0)
                
                streamed_len = (audio_tensor.shape[0] / self.sample_rate) 
                if streamed_len >= self.clip_len:
                    break # we have enough data

            audio_tensor = audio_tensor.reshape(1, -1) 
            if audio_tensor.shape[0] > 1:
                # the audio has more than 1 channel, convert to mono
                audio_tensor = torch.mean(audio_tensor, dim=0).unsqueeze(0)

            output = []
            audio_length = audio_tensor.size(1)

            ## pad or curtail
            #max_length = 99999999999999
            #if audio_length > max_length:
            #    max_start = audio_length - max_length
            #    start = torch.randint(0, max_start, (1, ))
            #    audio_tensor = audio_tensor[:, start:start + max_length]
            #else:
            #    audio_tensor = F.pad(audio_tensor, (0, max_length - audio_length), 'constant')

            audio_tensor = rearrange(audio_tensor, '1 ... -> ...')
            output = audio_tensor.float()
            return output

class SeabassIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, start, end, files: list[str], fs: MountFS, effects, clip_len:float = 2.0, tick_char='.', sample_rate=16000, sample_buffer_size=128):
        super(SeabassIterableDataset).__init__()
        self.start = start
        self.end = end
        self.effects = effects
        self.files = files
        self.fs = fs
        self.clip_len = clip_len
        self.tick_char = tick_char
        self.sample_rate = sample_rate
        self.sample_buffer_size = sample_buffer_size
        self.batcher = SeabassBatcherComputer.remote(clip_len, sample_rate)
        self.pending_samples = []
        
    # Start producing sasmples
    def run(self):
        self.pending_samples.extend(self.batcher.getSample.remote(self.files[i]) for i in range(self.sample_buffer_size))

    def getSample(self, index):
        ready_refs, self.pending_samples = ray.wait(self.pending_samples, num_returns=1)
        sample = ray.get(*ready_refs)
        ready_refs.append(self.batcher.getSample.remote(self.files[index + self.sample_buffer_size]))
        return sample
    
    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:  # single-process data loading, return the full iterator
            iter_start = self.start
            iter_end = self.end
        else:  # in a worker process
            # split workload
            per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
            worker_id = worker_info.id
            iter_start = self.start + worker_id * per_worker
            iter_end = min(iter_start + per_worker, self.end)

        for i in range(iter_start, iter_end):
            sample = self.getSample(i)
            yield sample

# Setup dataloader
effects = [
    ["norm", "1"],
]

needed_clip_len = float(320 * 40)/sample_rate

train_subset = music_files[:-1000]
val_subset = music_files[-1000:]

ds_t = SeabassIterableDataset(start=0, end=len(train_subset), files=train_subset, 
                            fs=dataset_fs, effects=effects, clip_len=needed_clip_len)

#ds_v = SeabassIterableDataset(start=0, end=len(val_subset), files=val_subset, 
#                            fs=dataset_fs, effects=effects, clip_len=2.0, tick_char='*')

ds_t.run()
sample = ds_t.getSample(0)
#time.sleep(3)
#ds_v.run()


2023-07-31 00:25:00,785	ERROR worker.py:405 -- Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): [36mray::SeabassBatcherComputer.getSample()[39m (pid=16192, ip=127.0.0.1, actor_id=05936a2a2189d4359735504301000000, repr=<__main__.SeabassBatcherComputer object at 0x0000023D88A85460>)
  File "python\ray\_raylet.pyx", line 1424, in ray._raylet.execute_task
  File "python\ray\_raylet.pyx", line 1364, in ray._raylet.execute_task.function_executor
  File "c:\Users\sebas\scoop\apps\mambaforge\current\envs\audiolm\lib\site-packages\ray\_private\function_manager.py", line 726, in actor_method_executor
    return method(__ray_actor, *args, **kwargs)
  File "c:\Users\sebas\scoop\apps\mambaforge\current\envs\audiolm\lib\site-packages\ray\util\tracing\tracing_helper.py", line 464, in _resume_span
    return method(self, *_args, **_kwargs)
  File "C:\Users\sebas\AppData\Local\Temp\ipykernel_41584\1530361082.py", line 22, in getSample
    with self.fs.open(filename, 'rb') as f:
Attrib

RayTaskError(AttributeError): [36mray::SeabassBatcherComputer.getSample()[39m (pid=16192, ip=127.0.0.1, actor_id=05936a2a2189d4359735504301000000, repr=<__main__.SeabassBatcherComputer object at 0x0000023D88A85460>)
  File "python\ray\_raylet.pyx", line 1424, in ray._raylet.execute_task
  File "python\ray\_raylet.pyx", line 1364, in ray._raylet.execute_task.function_executor
  File "c:\Users\sebas\scoop\apps\mambaforge\current\envs\audiolm\lib\site-packages\ray\_private\function_manager.py", line 726, in actor_method_executor
    return method(__ray_actor, *args, **kwargs)
  File "c:\Users\sebas\scoop\apps\mambaforge\current\envs\audiolm\lib\site-packages\ray\util\tracing\tracing_helper.py", line 464, in _resume_span
    return method(self, *_args, **_kwargs)
  File "C:\Users\sebas\AppData\Local\Temp\ipykernel_41584\1530361082.py", line 22, in getSample
    with self.fs.open(filename, 'rb') as f:
AttributeError: 'SeabassBatcherComputer' object has no attribute 'fs'

In [None]:
# Test the sampler
#dl_train.sampler
#sampler = ds_train.__iter__()
#sample = sampler.__next__()
#print(sample)
#print(waveform1.shape, sr1)
#plot_waveform(waveform1, sr1, title="Train", xlim=(-0.1, 3.2))
#plot_specgram(waveform1, sr1, title="Train", xlim=(0, 3.04))
#Audio(waveform1, rate=sr1)
#
#_, waveform2, sr2 = ds_val.__iter__()
#print(waveform2.shape, sr2)
#plot_waveform(waveform2, sr2, title="Val", xlim=(-0.1, 3.2))
#plot_specgram(waveform2, sr2, title="Val", xlim=(0, 3.04))
#Audio(waveform2, rate=sr2)


### SoundStream

In [None]:
soundstream = SoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 8,
)

trainer = SoundStreamTrainer(
    soundstream,
    #folder = dataset_folder,
    #train_dataloader=dl_train,
    #val_dataloader=dl_val,
    train_dataset=dst,
    val_dataset=dsv,
    #lr = 0.001,
    batch_size = 12,
    grad_accum_every = 8,         # effective batch size of 32
    data_max_length = 320 * 32,
    save_results_every = 20,
    save_model_every = 100,
    num_train_steps = 2000,
).cuda()
# NOTE: I changed num_train_steps to 9 (aka 8 + 1) from 10000 to make things go faster for demo purposes
# adjusting save_*_every variables for the same reason

trainer.train()

## Inference

In [None]:
# Everything together


In [None]:
output_path = "out.wav"
sample_rate = 44100
torchaudio.save(output_path, generated_wav.cpu(), sample_rate)