In [8]:
!nvidia-smi

# If this doesn't work, there's no GPU available or detected

Wed Feb  8 20:56:19 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.65.01    Driver Version: 515.65.01    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A10          On   | 00000000:06:00.0 Off |                    0 |
|  0%   30C    P8    21W / 150W |      2MiB / 23028MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [3]:
# !pip install audiolm-pytorch boto3 tensorboardX
!pip install boto3 tensorboardX
!pip install audiolm-pytorch
# !pip uninstall -y audiolm-pytorch
# raise AssertionError("don't forget to put in your patched version of audiolm and aws credentials!")
# tensorboardX required for lambda labs

Defaulting to user installation because normal site-packages is not writeable

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip available: [0m[31;49m22.3[0m[39;49m -> [0m[32;49m23.0[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3 -m pip install --upgrade pip[0m
Defaulting to user installation because normal site-packages is not writeable
Collecting audiolm-pytorch==0.11.1
  Downloading audiolm_pytorch-0.11.1-py3-none-any.whl (28 kB)
Collecting ema-pytorch
  Downloading ema_pytorch-0.1.4-py3-none-any.whl (4.2 kB)
Collecting transformers
  Downloading transformers-4.26.0-py3-none-any.whl (6.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.3/6.3 MB[0m [31m192.7 MB/s[0m eta [36m0:00:00[0m
Collecting fairseq
  Downloading fairseq-0.12.2-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl (11.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.0/11.0 MB[0m [31m79.8 MB/

## Setup

Includes:

- How to generate a placeholder dataset if you haven't already, just the basics to run "training" e2e on a tiny dataset
- How to download a dataset from OpenSLR

In [19]:
# from google.colab import drive
# drive.mount('/content/drive/')

# %set_env AWS_SHARED_CREDENTIALS_FILE=drive/MyDrive/Colab Notebooks/AWS-SECRET-credentials
# %set_env AWS_CONFIG_FILE=drive/MyDrive/Colab Notebooks/AWS-config


### Imports & paths

In [4]:
# imports
import math
import wave
import struct
import os
import urllib.request
import tarfile
from audiolm_pytorch import SoundStream, SoundStreamTrainer, HubertWithKmeans, SemanticTransformer, SemanticTransformerTrainer, HubertWithKmeans, CoarseTransformer, CoarseTransformerWrapper, CoarseTransformerTrainer, FineTransformer, FineTransformerWrapper, FineTransformerTrainer, AudioLM
from torch import nn
import torch
import torchaudio
import boto3
import datetime
from botocore.errorfactory import ClientError


# Paths shared by the cells below: dataset location and model checkpoints.
dataset_folder = "placeholder_dataset"
# The SoundStream checkpoint name embeds the step count, so it changes with the
# number of training steps.
soundstream_ckpt = "soundstream_results/soundstream.8.pt"
# HuBERT checkpoint plus its k-means quantizer (listed in row
# "HuBERT Base (~95M params)", column Quantizer).
hubert_ckpt = "hubert/hubert_base_ls960.pt"
hubert_quantizer = "hubert/hubert_base_ls960_L9_km500.bin"

  from pandas.core.computation.check import NUMEXPR_INSTALLED


### Data

In [21]:
# Placeholder data generation
def get_sinewave(freq=440.0, duration_ms=200, volume=1.0, sample_rate=44100.0):
  """Return a sine tone as a list of float samples.

  Args:
    freq: tone frequency in Hz.
    duration_ms: length of the tone in milliseconds.
    volume: peak amplitude multiplier applied to every sample.
    sample_rate: samples per second.

  Returns:
    A list with int(duration_ms * sample_rate / 1000) floats.
  """
  # code adapted from https://stackoverflow.com/a/33913403
  sample_count = int(duration_ms * (sample_rate / 1000.0))
  return [
    volume * math.sin(2 * math.pi * freq * (n / sample_rate))
    for n in range(sample_count)
  ]

def save_wav(file_name, audio, sample_rate=44100.0):
  """Write float samples to a mono, 16-bit PCM WAV file.

  Args:
    file_name: destination path handed to wave.open.
    audio: iterable of float samples, nominally in [-1.0, 1.0]; values outside
      that range are clipped instead of overflowing the 16-bit sample.
    sample_rate: frames per second stored in the header. 44100 is CD quality;
      lower it (e.g. 8000) to save on file size.
  """
  # wav params
  nchannels = 1
  sampwidth = 2
  comptype = "NONE"
  compname = "not compressed"
  # WAV files here use short, 16 bit, signed integers for the sample size, so
  # the floating point data is scaled by 32767, the maximum value for a short
  # integer, and clipped to that range.
  peak = 32767
  pcm = [max(-peak, min(peak, int(sample * 32767.0))) for sample in audio]
  # "<" forces little-endian, which is what WAV PCM data requires regardless of
  # the byte order of the machine running this code.
  frames = struct.pack(f"<{len(pcm)}h", *pcm)
  # The context manager closes the file even if writing fails part-way.
  with wave.open(file_name, "w") as wav_file:
    wav_file.setparams((nchannels, sampwidth, sample_rate, len(pcm), comptype, compname))
    wav_file.writeframes(frames)
  return

def make_placeholder_dataset():
  """Create a tiny folder of sine-wave .wav files under `dataset_folder`.

  The files only exist so that "training" can be exercised end to end. Nothing
  is written when the folder already exists.
  """
  if os.path.isdir(dataset_folder):
    return
  os.makedirs(dataset_folder)
  # Two clips at the top level: the default tone and a longer one.
  top_level_clips = (("example.wav", {}), ("example2.wav", {"duration_ms": 500}))
  for clip_name, tone_kwargs in top_level_clips:
    save_wav(f"{dataset_folder}/{clip_name}", get_sinewave(**tone_kwargs))
  # One clip in a nested folder, at a different pitch.
  nested_folder = f"{dataset_folder}/subdirectory"
  os.makedirs(nested_folder)
  save_wav(f"{nested_folder}/example.wav", get_sinewave(freq=330.0))

make_placeholder_dataset()

In [5]:
# Get actual dataset. Running this cell downloads real speech data and points
# dataset_folder at it; skip the cell to keep using the placeholder dataset.

# full dataset: https://www.openslr.org/12
# We'll use https://us.openslr.org/resources/12/dev-clean.tar.gz development set, "clean" speech.
# We *should* train on, well, training, but this is just to demo running things end-to-end at all so I just picked a small clean set.

url = "https://us.openslr.org/resources/12/dev-clean.tar.gz"
filename = "dev-clean"
filename_targz = filename + ".tar.gz"
if not os.path.isfile(filename_targz):
  # Download under a temporary name and rename only once complete, so an
  # interrupted download is never mistaken for a finished archive on re-run.
  _partial_targz = filename_targz + ".part"
  urllib.request.urlretrieve(url, _partial_targz)
  os.replace(_partial_targz, filename_targz)
if not os.path.isdir(filename):
  # Same idea for extraction: unpack next to the target and rename at the end,
  # so a half-extracted folder does not pass the isdir check above.
  _partial_dir = filename + ".partial"
  with tarfile.open(filename_targz) as t:
    t.extractall(_partial_dir)
  os.replace(_partial_dir, filename)
dataset_folder = filename # update dataset_folder so we use the right dataset

### S3 data

In [6]:
# S3 configuration used by the checkpoint helpers defined below.
bucket_name = "itsleonwu-paperspace"
# Sample file names, referenced only by the commented-out upload/download
# snippet that follows.
filename = "sample_data/README.md"
uploaded_filename = filename + "_upload"
downloaded_filename = filename + "_download"

# The default session must be configured before any resource or client is
# created from it.
boto3.setup_default_session(profile_name="paperspace")
s3_resource = boto3.resource("s3")
bucket = s3_resource.Bucket(name=bucket_name)
# Low-level client, kept alongside the resource to check whether a key exists.
s3 = boto3.client("s3")
 
# bucket.upload_file(
#     Filename=filename, Key=uploaded_filename)

# obj = bucket.Object(uploaded_filename)
# obj.download_file(downloaded_filename) 

def add_trailing_slash(folder):
  """Return `folder` as a string that ends with "/".

  Args:
    folder: a non-empty path given as a str or any os.PathLike object.

  Returns:
    The path as a str, with "/" appended unless it already ends with one.

  Raises:
    ValueError: if `folder` is the empty string.
  """
  # os.fspath lets path objects through; indexing one with [-1] would fail.
  folder = os.fspath(folder)
  if not folder:
    raise ValueError("folder must be a non-empty path")
  if folder.endswith("/"):
    return folder
  return folder + "/"

def write_folder_to_s3(folder):
  """Upload every entry of the local `folder` to S3 under the same key.

  The bucket is expected to hold none of the local file names under the
  `folder` prefix yet; an overlap is treated as a leftover from an earlier run.

  Args:
    folder: local folder name, also used as the S3 key prefix.

  Raises:
    AssertionError: if a local file name already exists under the prefix in S3.
  """
  folder = add_trailing_slash(folder)
  # Strip the prefix by length: splitting on the folder text picks the wrong
  # piece whenever that text occurs again later in the key.
  filenames = {object_summary.key[len(folder):] for object_summary in bucket.objects.filter(Prefix=folder)}
  local_files = set(os.listdir(folder))
  overlap = filenames.intersection(local_files)
  if overlap:
    raise AssertionError(f"Found files in local that already exist in s3 bucket. suspicious, probably old training run that wasn't cleaned up properly: {overlap}")
  # No overlap at this point, so every local file is new; sort for a stable log.
  for file in sorted(local_files):
    print(f"writing file: {folder}{file}")
    bucket.upload_file(Filename=f"{folder}{file}", Key=f"{folder}{file}")

def trainer_loop(trainer, results_folder, save_every, num_train_steps, num_steps_so_far=0):
  """Run training in cycles of `save_every` steps, sending each checkpoint to S3.

  Each cycle reloads the previous checkpoint, checks that it exists in S3,
  deletes the local copy, trains, saves `<total steps>.pt` into
  `results_folder` and uploads it.

  For simplicity, save_model_every is expected to match save_results_every,
  even though the trainer allows them to differ.

  Args:
    trainer: trainer exposing save/load/train/register_buffer plus the
      results_folder and num_train_steps attributes.
    results_folder: local folder (and S3 key prefix) for the numbered
      checkpoints. Must differ from trainer.results_folder, whose files are
      deleted after every cycle.
    save_every: number of training steps per cycle.
    num_train_steps: number of ADDITIONAL training steps; must satisfy
      num_train_steps % save_every == 1.
    num_steps_so_far: 0 for a fresh run; otherwise the step count of the
      checkpoint to download from S3 and resume from.

  Raises:
    AssertionError: if the two results folders coincide, if the remainder
      rule is violated, or if the previous checkpoint cannot be confirmed in S3.
  """
  results_folder = add_trailing_slash(results_folder)
  # trainer.results_folder may be a path object rather than a str, so convert
  # before normalising and comparing.
  if add_trailing_slash(str(trainer.results_folder)) == results_folder:
    raise AssertionError("temp hack, let's make results_folder be different "
    "from trainer.results_folder (the former we use for putting the actual "
    "checkpoints.")
  # makedirs also handles nested folders and an already existing folder.
  os.makedirs(results_folder, exist_ok=True)

  num_checkpoint_cycles = int(num_train_steps / save_every)
  if num_train_steps % save_every != 1:
    raise AssertionError(
        f"please make remainder 1 to make sure it saves and we don't waste "
        f"compute-- we got save_every {save_every} and num_train_steps "
        f"{num_train_steps} so mod is {num_train_steps % save_every}")
  if num_steps_so_far == 0:
    # save the initial as a checkpoint before doing any training as well
    trainer.save(os.path.join(results_folder, "0.pt"))
    write_folder_to_s3(results_folder)
  else:
    # download the checkpoint back from s3.
    download_filename = os.path.join(results_folder, f"{num_steps_so_far}.pt")
    obj = bucket.Object(download_filename)
    obj.download_file(download_filename)

  for checkpoint_i in range(num_checkpoint_cycles):
    # load previous checkpoint.
    prev_checkpoint_filename = f"{checkpoint_i * save_every + num_steps_so_far}.pt"
    trainer.load(os.path.join(results_folder, prev_checkpoint_filename))
    # Then we delete locally, cleaning checkpoints up- should already exist in s3 by this point anyhow
    try:
      s3.head_object(Bucket=bucket_name, Key=f"{results_folder}{prev_checkpoint_filename}") # assert exists in s3
    except ClientError as e:
      # S3 error codes are not always numeric (e.g. "NoSuchKey"), so compare as
      # strings; int() on a symbolic code would raise and hide the real error.
      error_code = str(e.response['Error']['Code'])
      if error_code in ("404", "NoSuchKey"):
        raise AssertionError(f'File {results_folder}{prev_checkpoint_filename} does not exist') from e
      raise AssertionError(f"got error {e} with error code {error_code}") from e
    os.remove(os.path.join(results_folder, prev_checkpoint_filename))

    # (checkpoint_i + 1) since e.g. for checkpoint_i == 0, we've just actually run the first save_every train steps and it's really save_every.pt not 0.pt
    checkpoint_filename = f"{(checkpoint_i + 1) * save_every + num_steps_so_far}.pt"
    # this is a hack to make the trainer think it is essentially starting the
    # current run "from scratch", since we hack num_train_steps to be a small
    # number to force it to save early
    trainer.register_buffer('steps', torch.Tensor([0]))
    trainer.num_train_steps = save_every + 1 # + 1 ensures we save the checkpoint
    trainer.train()
    trainer.save(os.path.join(results_folder, checkpoint_filename))
    for file in os.listdir(trainer.results_folder):
      # delete temporary checkpoints saved automatically by trainer but keep the folder
      os.remove(os.path.join(trainer.results_folder, file))

    s3_start = datetime.datetime.now()
    write_folder_to_s3(results_folder)
    print(f"write_to_s3_time: {(datetime.datetime.now() - s3_start).total_seconds()}")



## Training

Now that we have a dataset, we can train AudioLM.

**Note**: do NOT type "y" to overwrite previous experiments/checkpoints when running through the cells here unless you're ready to lose the entire results folder! Otherwise you will end up erasing things (e.g. you train SoundStream first, and if you choose "overwrite" then you lose the SoundStream checkpoint when you then train SemanticTransformer).

### SoundStream

In [37]:
# !cp -r dev-clean/LibriSpeech/dev-clean/1272/128104 .
# !du -sh 128104 # 2.8 MB
# dataset_folder = "128104"

# https://github.com/lucidrains/audiolm-pytorch/archive/refs/heads/main.zip

In [16]:
# cleanup
# WARNING: destructive. Removes and recreates both local results folders, then
# deletes every object in the S3 bucket bound to `bucket` (not only the
# checkpoint keys).
!rm -r soundstream_results actual_soundstream_results
!mkdir soundstream_results actual_soundstream_results
bucket.objects.all().delete()

[{'ResponseMetadata': {'RequestId': '5YSHBQ79J893699Q',
   'HostId': '1YAewsUC23/YqzEEYzIwNOaZb/0zfGHhRX6iVxCrBeU2OpebjlhPB4ClJ5//sUJml+DVIoXMMGM=',
   'HTTPStatusCode': 200,
   'HTTPHeaders': {'x-amz-id-2': '1YAewsUC23/YqzEEYzIwNOaZb/0zfGHhRX6iVxCrBeU2OpebjlhPB4ClJ5//sUJml+DVIoXMMGM=',
    'x-amz-request-id': '5YSHBQ79J893699Q',
    'date': 'Wed, 08 Feb 2023 21:04:24 GMT',
    'content-type': 'application/xml',
    'transfer-encoding': 'chunked',
    'server': 'AmazonS3',
    'connection': 'close'},
   'RetryAttempts': 0},
  'Deleted': [{'Key': 'actual_soundstream_results/0.pt'},
   {'Key': 'actual_soundstream_results/4.pt'},
   {'Key': 'actual_soundstream_results/8.pt'},
   {'Key': 'actual_soundstream_results/2.pt'},
   {'Key': 'actual_soundstream_results/6.pt'}]}]

In [None]:
# Train SoundStream through trainer_loop so that every checkpoint is uploaded to S3.
from audiolm_pytorch import AudioLMSoundStream

soundstream = AudioLMSoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 8,
    attn_window_size = 128,       # local attention receptive field at bottleneck
    attn_depth = 2                # 2 local attention transformer blocks - the soundstream folks were not experts with attention, so i took the liberty to add some. encodec went with lstms, but attention should be better
)

# trainer_loop requires num_train_steps % save_every == 1, hence 20 * 1000 + 1.
actual_num_train_steps = 20001
save_every = 1000

trainer = SoundStreamTrainer(
    soundstream,
    folder = dataset_folder,
    lr=3e-4,
    batch_size = 4,
    grad_accum_every = 8, # effective batch size of batch_size * grad_accum_every = 32
    data_max_length_seconds = 2,  # train on 2 second audio
    # Scratch folder for the trainer's own files; trainer_loop empties it after
    # every cycle, so it must differ from the folder passed to trainer_loop.
    results_folder = "soundstream_results",
    save_results_every = save_every,
    save_model_every = save_every,
    num_train_steps = actual_num_train_steps
).cuda()

trainer_loop_start = datetime.datetime.now()
# The step count is passed again here because trainer_loop overwrites
# trainer.num_train_steps on every cycle.
trainer_loop(trainer,
             "actual_soundstream_results",
             save_every=save_every,
             num_train_steps=actual_num_train_steps,
             num_steps_so_far=0)
print(f"trainer_loop_time for {actual_num_train_steps} steps: "
      f"{(datetime.datetime.now() - trainer_loop_start).total_seconds()}")

training with dataset of 2567 samples and validating with randomly splitted 136 samples
writing file: actual_soundstream_results/0.pt
0: soundstream total loss: 1061894160.000, soundstream recon loss: 0.958 | discr (scale 1) loss: 1.999 | discr (scale 0.5) loss: 2.003 | discr (scale 0.25) loss: 2.003
0: saving to soundstream_results
0: saving model to soundstream_results
1: soundstream total loss: 868120560.000, soundstream recon loss: 0.472 | discr (scale 1) loss: 1.961 | discr (scale 0.5) loss: 1.973 | discr (scale 0.25) loss: 1.981
2: soundstream total loss: 962052024.000, soundstream recon loss: 0.829 | discr (scale 1) loss: 1.851 | discr (scale 0.5) loss: 1.885 | discr (scale 0.25) loss: 1.924
3: soundstream total loss: 1125865040.000, soundstream recon loss: 0.796 | discr (scale 1) loss: 1.906 | discr (scale 0.5) loss: 1.841 | discr (scale 0.25) loss: 1.887
4: soundstream total loss: 667848528.000, soundstream recon loss: 0.303 | discr (scale 1) loss: 1.877 | discr (scale 0.5) lo

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



396: soundstream total loss: 20077035.000, soundstream recon loss: 0.012 | discr (scale 1) loss: 0.061 | discr (scale 0.5) loss: 0.227 | discr (scale 0.25) loss: 0.248
397: soundstream total loss: 20493191.750, soundstream recon loss: 0.007 | discr (scale 1) loss: 0.027 | discr (scale 0.5) loss: 0.268 | discr (scale 0.25) loss: 0.196
398: soundstream total loss: 22321013.750, soundstream recon loss: 0.007 | discr (scale 1) loss: 0.028 | discr (scale 0.5) loss: 0.142 | discr (scale 0.25) loss: 0.084
399: soundstream total loss: 21620670.500, soundstream recon loss: 0.009 | discr (scale 1) loss: 0.038 | discr (scale 0.5) loss: 0.170 | discr (scale 0.25) loss: 0.128
400: soundstream total loss: 19487751.875, soundstream recon loss: 0.007 | discr (scale 1) loss: 0.211 | discr (scale 0.5) loss: 0.294 | discr (scale 0.25) loss: 0.202
401: soundstream total loss: 19983471.875, soundstream recon loss: 0.012 | discr (scale 1) loss: 0.070 | discr (scale 0.5) loss: 0.183 | discr (scale 0.25) loss

### SemanticTransformer

In [4]:
# Deliberate guard: this raise stops the cell, so nothing below it runs until
# the guard is removed.
raise AssertionError("don't run this in trainer loop until you've fixed trainer_loop to accomodate non-soundstream checkpoints, see filename")
# hubert checkpoints can be downloaded at
# https://github.com/facebookresearch/fairseq/tree/main/examples/hubert
if not os.path.isdir("hubert"):
  os.makedirs("hubert")
# hubert_ckpt and hubert_quantizer serve both as the local relative path and
# as the path component of the download URL.
if not os.path.isfile(hubert_ckpt):
  hubert_ckpt_download = f"https://dl.fbaipublicfiles.com/{hubert_ckpt}"
  urllib.request.urlretrieve(hubert_ckpt_download, f"./{hubert_ckpt}")
if not os.path.isfile(hubert_quantizer):
  hubert_quantizer_download = f"https://dl.fbaipublicfiles.com/{hubert_quantizer}"
  urllib.request.urlretrieve(hubert_quantizer_download, f"./{hubert_quantizer}")

wav2vec = HubertWithKmeans(
    checkpoint_path = f'./{hubert_ckpt}',
    kmeans_path = f'./{hubert_quantizer}'
)

semantic_transformer = SemanticTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    dim = 1024,
    depth = 6
).cuda()


# num_train_steps = 1 keeps this a smoke test rather than a real training run.
trainer = SemanticTransformerTrainer(
    transformer = semantic_transformer,
    wav2vec = wav2vec,
    folder = dataset_folder,
    batch_size = 1,
    data_max_length = 320 * 32,
    num_train_steps = 1,
    results_folder = "./semantic_results",
)

trainer.train()

AssertionError: don't run this in trainer loop until you've fixed trainer_loop to accomodate non-soundstream checkpoints, see filename

### CoarseTransformer

In [5]:
# Train the coarse transformer on top of HuBERT semantic tokens and a
# SoundStream codec loaded from soundstream_ckpt.
# Depends on the imports/paths cell; the NameError in the recorded output is
# what happens when this cell runs before that one.
wav2vec = HubertWithKmeans(
    checkpoint_path = f'./{hubert_ckpt}',
    kmeans_path = f'./{hubert_quantizer}'
)

soundstream = SoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 8,
)

# NOTE(review): soundstream_ckpt points into soundstream_results/, while
# trainer_loop writes its checkpoints elsewhere — confirm the path before running.
soundstream.load(f"./{soundstream_ckpt}")

coarse_transformer = CoarseTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    codebook_size = 1024,
    num_coarse_quantizers = 3,
    dim = 512,
    depth = 6
)

trainer = CoarseTransformerTrainer(
    transformer = coarse_transformer,
    soundstream = soundstream,
    wav2vec = wav2vec,
    folder = dataset_folder,
    batch_size = 1,
    data_max_length = 320 * 32,
    results_folder = "./coarse_results",
    save_results_every = 2,
    save_model_every = 4,
    num_train_steps = 9
)
# NOTE: I changed num_train_steps to 9 (aka 8 + 1) from 10000 to make things go faster for demo purposes
# adjusting save_*_every variables for the same reason

trainer.train()

NameError: name 'HubertWithKmeans' is not defined

### FineTransformer

In [6]:
# Train the fine transformer on a SoundStream codec loaded from soundstream_ckpt.
# Depends on the imports/paths cell; the NameError in the recorded output is
# what happens when this cell runs before that one.
soundstream = SoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 8,
)

soundstream.load(f"./{soundstream_ckpt}")

# 3 coarse + 5 fine quantizers = the 8 residual quantizers of the codec above.
fine_transformer = FineTransformer(
    num_coarse_quantizers = 3,
    num_fine_quantizers = 5,
    codebook_size = 1024,
    dim = 512,
    depth = 6
)

trainer = FineTransformerTrainer(
    transformer = fine_transformer,
    soundstream = soundstream,
    folder = dataset_folder,
    batch_size = 1,
    data_max_length = 320 * 32,
    num_train_steps = 9,
    results_folder = "./fine_results",
)
# NOTE: I changed num_train_steps to 9 (aka 8 + 1) from 10000 to make things go faster for demo purposes
# (no save_*_every values are passed here, so the trainer's own defaults apply)

trainer.train()

NameError: name 'SoundStream' is not defined

## Inference

In [21]:
# Everything together
audiolm = AudioLM(
    wav2vec = wav2vec,
    soundstream = soundstream,
    semantic_transformer = semantic_transformer,
    coarse_transformer = coarse_transformer,
    fine_transformer = fine_transformer
)

generated_wav = audiolm(batch_size = 1)

generating semantic:  15%|█▌        | 310/2048 [00:03<00:20, 85.27it/s]
generating coarse:  48%|████▊     | 244/512 [00:07<00:08, 31.48it/s]


KeyboardInterrupt: ignored

In [None]:
output_path = "out.wav"
# Save at the codec's own sample rate instead of a hard-coded 44100: the audio
# is produced by `soundstream`, and writing it with a different rate in the
# header changes the pitch and speed on playback.
sample_rate = soundstream.target_sample_hz
torchaudio.save(output_path, generated_wav.cpu(), sample_rate)

In [26]:
!ls -l fine_results

total 98176
-rw-r--r-- 1 root root 100528755 Jan 31 20:33 fine.transformer.0.pt


In [1]:
!pip install boto3

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting boto3
  Downloading boto3-1.26.61-py3-none-any.whl (132 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m132.7/132.7 KB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting s3transfer<0.7.0,>=0.6.0
  Downloading s3transfer-0.6.0-py3-none-any.whl (79 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.6/79.6 KB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting botocore<1.30.0,>=1.29.61
  Downloading botocore-1.29.61-py3-none-any.whl (10.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.4/10.4 MB[0m [31m21.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting jmespath<2.0.0,>=0.7.1
  Downloading jmespath-1.0.1-py3-none-any.whl (20 kB)
Collecting urllib3<1.27,>=1.25.4
  Downloading urllib3-1.26.14-py2.py3-none-any.whl (140 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m140.6/140.6 KB[0m 

In [20]:
# # save to my personal bucket
# bucket = s3_resource.Bucket(name=bucket_name)
# saved_bucket = s3_resource.Bucket(name="hello-ok-zoomer")

# for i in range(21):
#     obj = bucket.Object(f"actual_soundstream_results/{i * 1000}.pt")
#     obj.download_file(f"{i * 1000}.pt")
#     saved_bucket.upload_file(Filename=f"{i * 1000}.pt", Key=f"actual_soundstream_results/{i * 1000}.pt")

In [52]:
# Rebuild a SoundStream trainer and restore the final checkpoint so the next
# cell can draw reconstruction samples from it.
soundstream = SoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 8,
)

actual_num_train_steps = 20001
save_every = 1000

# NOTE(review): constructing the trainer prompts whether to clear previous
# checkpoints and results (see the recorded output) — answer with care.
trainer = SoundStreamTrainer(
    soundstream,
    folder = dataset_folder,
    lr=3e-4,
    batch_size = 4,
    grad_accum_every = 8, # effective batch size of batch_size * grad_accum_every = 32
    data_max_length = 16000,
    results_folder = "soundstream_results",
    save_results_every = save_every,
    save_model_every = save_every,
    num_train_steps = actual_num_train_steps
).cuda()
# The checkpoint must already be present locally under actual_soundstream_results/.
trainer.load("actual_soundstream_results/20000.pt")

training with dataset of 2567 samples and validating with randomly splitted 136 samples


do you want to clear previous experiment checkpoints and results? (y/n)  y


In [53]:
# Write reconstructions of validation clips from both the EMA model and the
# raw model, so they can be listened to side by side.
self = trainer
steps = 20000
# self.soundstream.target_sample_hz = 24000
eval_models = (
    (self.ema_soundstream.ema_model, f'{steps}.ema'),
    (self.soundstream, str(steps)),
)
for model, _ in eval_models:
    model.eval()

# Gradients are not needed for sampling.
with torch.no_grad():
    for i in range(10):
        for model, label in eval_models:
            # Named valid_wave so the stdlib `wave` module imported at the top
            # of the notebook is not rebound. Each model draws its own batch.
            valid_wave, = next(self.valid_dl_iter)
            valid_wave = valid_wave.to(self.device)

            recons = model(valid_wave, return_recons_only = True)

            for ind, recon in enumerate(recons.unbind(dim = 0)):
                # Label, round and batch index all go into the name; otherwise
                # every model and batch item overwrites the same file.
                sample_path = str(self.results_folder / f'sample_{label}_{i}_{ind}.flac')
                torchaudio.save(sample_path, recon.cpu(), self.soundstream.target_sample_hz)

In [49]:
# Scratch inspection of the trainer bound to `self` in the previous cell.
self.soundstream.target_sample_hz
# NOTE(review): this line raised AttributeError when run: self.ds is a `Subset`
# with no `max_length`. Presumably the wrapped dataset (self.ds.dataset) holds
# that attribute — TODO confirm.
self.ds.max_length

AttributeError: 'Subset' object has no attribute 'max_length'