# (Nearly) One-Click Jukebox reprime notebook by Michaels Lab

Join the Jukebox community at https://discord.gg/aEqXFN9amV

Video guide by Broccaloo for the original notebook: https://vimeo.com/817968335

Explanation video by Michaels Lab about the notebook: https://youtu.be/BPo5sECkBV4?si=4niezv12iyElPn7J

Speed upsampling is supported. The notebook switches to upsample mode automatically if a saved data file is detected in the folder you provide.
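To make the automatic switch less mysterious, here is a minimal sketch of the lookup order the Model Settings cell appears to follow: it first tries `level_1/data.pth.tar`, then `tokens.t`, and otherwise starts from scratch. The helper name `detect_resume_mode` is made up for this illustration, and it only checks that the files exist, whereas the real cell actually tries to load them.

```python
import tempfile
from pathlib import Path

def detect_resume_mode(folder):
    """Report which saved state, if any, exists in the project folder.

    Hypothetical helper: it mirrors the lookup order of the Model Settings
    cell (upsampling checkpoint, then top-level tokens, then nothing).
    """
    folder = Path(folder)
    if (folder / "level_1" / "data.pth.tar").is_file():
        return "upsample checkpoint"
    if (folder / "tokens.t").is_file():
        return "top-level tokens"
    return "fresh start"

# Usage with a throwaway folder standing in for hps.name
with tempfile.TemporaryDirectory() as tmp:
    print(detect_resume_mode(tmp))        # fresh start
    (Path(tmp) / "tokens.t").touch()
    print(detect_resume_mode(tmp))        # top-level tokens
```

If you want a clean run, pointing hps.name at a folder that contains only primer.wav should give you the "fresh start" case.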

**How to handle memory problems:** In theory, this notebook is crafted to avoid out-of-memory errors, but here are some tricks if you still encounter one:
* Restart runtime: At the top of the notebook, click "Runtime" and then "Restart runtime". Then run everything again. You should do this every time you start a second run within the same session or after you've interrupted one.
* Decrease sample count: Choose a lower number for 'hps.n_samples'.

Notebook re-made by Mellck Borges (me) and inspired by Broccaloo

In [None]:
!nvidia-smi

Fri Nov 17 21:11:17 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   42C    P8     9W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

The documentation in this notebook is written for the new jukebox-saveopt, so it is recommended that you keep this checkbox on.

In [None]:
use_new_jukebox_saveopt = True #@param {type: "boolean"}

# Install and setup

In [None]:
#@title Install and Import Jukebox
from google.colab import drive
drive.mount('/content/gdrive')

if use_new_jukebox_saveopt:
  !pip install --upgrade git+https://github.com/craftmine1000/jukebox-saveopt.git
else:
  !pip install --upgrade git+https://github.com/craftmine1000/jukebox-opt.git torch==1.8

import jukebox
import torch as t
import torch.nn.functional as F
import librosa
import os
from IPython.display import Audio
from jukebox.make_models import make_vqvae, make_prior, MODELS, make_model
from jukebox.hparams import Hyperparams, setup_hparams
from jukebox.sample import sample_single_window, _sample, \
                           sample_partial_window, upsample, \
                           load_prompts
from jukebox.utils.dist_utils import setup_dist_from_mpi
from jukebox.utils.torch_utils import empty_cache
rank, local_rank, device = setup_dist_from_mpi()

Mounted at /content/gdrive
Collecting git+https://github.com/craftmine1000/jukebox-saveopt.git
  Cloning https://github.com/craftmine1000/jukebox-saveopt.git to /tmp/pip-req-build-hau5j84q
  Running command git clone --filter=blob:none --quiet https://github.com/craftmine1000/jukebox-saveopt.git /tmp/pip-req-build-hau5j84q
  Resolved https://github.com/craftmine1000/jukebox-saveopt.git to commit 5b76e9e07eb15dba6fd79da99d76d1ecb32a7ea5
  Preparing metadata (setup.py) ... done
Collecting fire>=0.1.3 (from jukebox==1.0)
  Downloading fire-0.5.0.tar.gz (88 kB)
     88.3/88.3 kB 2.2 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
Collecting unidecode>=1.1.1 (from jukebox==1.0)
  Downloading Unidecode-1.3.7-py3-none-any.whl (235 kB)
     235.5/235.5 kB 8.4 MB/s eta 0:00:00
Collecting mpi4py>=

# Settings

## ---Setup Model

In [None]:
#@title Model Settings
model = "1b_lyrics" #@param ['5b_lyrics', '5b', '1b_lyrics']
hps = Hyperparams()
hps.sr = 44100
hps.n_samples = 2 #@param {type: "integer"}
hps.name = '/content/gdrive/My Drive/A.I/' #@param {type: "string"}
chunk_size = 128  # currently the same for every model
#@markdown ## Level specific settings
#@markdown [level 0, level 1, level 2]
hps.hop_fraction = [1, 4, .125] #@param
batch_sizes = [8, 12, 4] #@param
hps.levels = 3

if not use_new_jukebox_saveopt:
  for i in range(2):
    if hps.hop_fraction[i] > 1:
      hps.hop_fraction[i] = 1

primer_length_in_seconds = librosa.get_duration(filename=hps.name + 'primer.wav')
duration = (int(primer_length_in_seconds * hps.sr) // 128) * 128
hps.sample_length = duration

raw_audio = load_prompts([hps.name + 'primer.wav'], duration, hps)

#print(raw_audio)
#print(raw_audio.shape)
#print(t.max(t.abs(raw_audio)))

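# Resume from earlier progress if any exists in hps.name:
# first an upsampling checkpoint, then saved top-level tokens, otherwise start empty.
try:
  try:
    zs = t.load(f'{hps.name}level_1/data.pth.tar')['zs']
  except Exception:
    zs = t.load(f'{hps.name}tokens.t')
except Exception:
  zs = [t.zeros(hps.n_samples,0,dtype=t.long, device='cpu') for _ in range(hps.levels)]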

vqvae, *priors = MODELS[model]
vqvae = make_vqvae(setup_hparams(vqvae, dict(sample_length = 1048576)), device)
#vqvae.c_to(device)
if zs[-1].shape[1] < duration // 128:
  top_prior = make_prior(setup_hparams(priors[-1], dict()), vqvae, device)
  if use_new_jukebox_saveopt:
    top_prior.c_to(device)

	This alias will be removed in version 1.0.
  primer_length_in_seconds = librosa.get_duration(filename=hps.name + 'primer.wav')


Downloading from azure
Running  wget -O /root/.cache/jukebox/models/5b/vqvae.pth.tar https://openaipublic.azureedge.net/jukebox/models/5b/vqvae.pth.tar
Restored from /root/.cache/jukebox/models/5b/vqvae.pth.tar
0: Loading vqvae in eval mode
Creating cond. autoregress with prior bins [79, 2048], 
dims [384, 6144], 
shift [ 0 79]
input shape 6528
input bins 2127
Self copy is False
Loading artist IDs from /usr/local/lib/python3.10/dist-packages/jukebox/data/ids/v3_artist_ids.txt
Loading artist IDs from /usr/local/lib/python3.10/dist-packages/jukebox/data/ids/v3_genre_ids.txt
Level:2, Cond downsample:None, Raw to tokens:128, Sample length:786432
Downloading from azure
Running  wget -O /root/.cache/jukebox/models/1b_lyrics/prior_level_2.pth.tar https://openaipublic.azureedge.net/jukebox/models/1b_lyrics/prior_level_2.pth.tar
Restored from /root/.cache/jukebox/models/1b_lyrics/prior_level_2.pth.tar
0: Loading prior in eval mode


### model Instructions
The model setting selects which pre-trained Jukebox model to download and run.

The number before "b" is the parameter count in billions, so 5b and 5b_lyrics both have 5 billion parameters while 1b_lyrics has 1 billion.

1b_lyrics is faster to load and run but isn't as good as the 5b models.

### hps.n_samples Instructions
How many different songs you want to reprime in parallel.

### hps.name Instructions
The folder where Jukebox saves its progress and writes the generated wav files. The primer must also be located here.

### Primer Instructions
The "primer" is the sound file that Jukebox generates new audio from. It should be located in the folder set by hps.name, and its file name should be primer.wav.

Example: if hps.name is "somedir/moredirs/", then primer.wav should live there, i.e. "somedir/moredirs/primer.wav".
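One detail worth spelling out: the Model Settings cell builds the primer path by plain string concatenation (`hps.name + 'primer.wav'`), so the trailing slash on hps.name matters. The sketch below shows what happens with and without it. `normalized_folder` is a hypothetical helper added for illustration, not something the notebook provides.

```python
def primer_path(folder):
    """Build the primer path the way the notebook does: plain string concatenation."""
    return folder + "primer.wav"

def normalized_folder(folder):
    """Hypothetical helper: make sure the folder ends with exactly one slash."""
    return folder.rstrip("/") + "/"

print(primer_path("somedir/moredirs/"))                     # somedir/moredirs/primer.wav
print(primer_path("somedir/moredirs"))                      # somedir/moredirsprimer.wav (wrong file)
print(primer_path(normalized_folder("somedir/moredirs")))   # somedir/moredirs/primer.wav
```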

### Levels Information
Jukebox uses 3 different levels to represent sound. The first level (level 0) is the least compressed (data-wise) and the least noisy (audio-wise); the third level (level 2) is the most compressed and the most noisy; the second level (level 1) is somewhere in between.

Jukebox generates new audio on the third level (level 2), which is then upsampled to the second and finally the first level.
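To get a feel for how much each level compresses the audio, here is a small sketch that counts the tokens needed per level for a primer of a given length. The 128 samples per token at level 2 comes from this notebook's own log ("Raw to tokens:128"). The values 8 and 32 for levels 0 and 1 are an assumption based on the published Jukebox configuration, since this notebook does not print them.

```python
SAMPLE_RATE = 44100                      # hps.sr in the Model Settings cell
RAW_TO_TOKENS = {0: 8, 1: 32, 2: 128}    # level 2 from the notebook log; 0 and 1 assumed

def trimmed_length(seconds, sr=SAMPLE_RATE):
    """Round the primer length down to a multiple of 128 samples, as the notebook does."""
    return (int(seconds * sr) // 128) * 128

def tokens_per_level(seconds):
    """Number of tokens needed at each level to cover the trimmed primer."""
    samples = trimmed_length(seconds)
    return {level: samples // hop for level, hop in RAW_TO_TOKENS.items()}

print(trimmed_length(10))      # 440960
print(tokens_per_level(10))    # {0: 55120, 1: 13780, 2: 3445}
```

Under these assumptions, level 0 needs 16 times as many tokens as level 2 for the same audio, which is a large part of why upsampling takes so much longer than top-level generation.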

### hps.hop_fraction Instructions
Decides how far the context window slides with each generation hop for that level.

If the value is an integer at or above 1, it turns on speed-upsampling for that level. This does not work on level 2.

When at or above 1, the value is actually used as a divisor for how much of the context is used to upsample. The further into the context Jukebox upsamples and generates, the more compute it takes, which leads to a speed falloff at the end of each hop. Using less of the context mitigates this, hence the divisor.

It is set to 1 at level 0 because even dividing by 2 creates very audible artifacts. Level 1 is much less prone to this, which is why its default is 4. 8 probably also works, but I haven't tried it.
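As a rough illustration of the two meanings of this setting, the sketch below follows one reading of the text above: values below 1 are a fraction of the context window, and whole numbers at or above 1 divide it. This is not the jukebox-saveopt implementation; the function name `tokens_used_per_hop` and the context size of 8192 are assumptions chosen only to make the arithmetic concrete.

```python
def tokens_used_per_hop(n_ctx, hop_fraction):
    """Context tokens covered per hop, under the reading described in the lead-in.

    Below 1: a fraction of the context window.
    At or above 1: the context window divided by the value (speed-upsampling).
    """
    if hop_fraction >= 1:
        if hop_fraction != int(hop_fraction):
            raise ValueError("values at or above 1 must be whole numbers")
        return n_ctx // int(hop_fraction)
    return int(n_ctx * hop_fraction)

N_CTX = 8192  # illustrative context size, not read from any model
for level, fraction in enumerate([1, 4, 0.125]):   # the notebook defaults
    print(level, tokens_used_per_hop(N_CTX, fraction))
# 0 8192
# 1 2048
# 2 1024
```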


### batch_sizes Instructions
Limits the maximum number of chunks/batches that are processed in parallel for that level. A lower batch size means less RAM usage, but also slower processing.

Because of how continuous reprime works, it does not require past context, which means level 2 can be generated in parallel chunks.

Thanks to speed-upsampling, levels 1 and 0 are also processed in parallel chunks.

This means performance is now directly tied to batch size for all levels, up to a point of course.
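The interaction between batch size and hps.n_samples is easy to miss. In the sampling cell, the number of level 2 windows scheduled per pass is `batch_sizes[2] // hps.n_samples`, because every window needs one batch row per sample. The helper name `windows_per_pass` below is made up for this sketch.

```python
def windows_per_pass(batch_size, n_samples):
    """How many level 2 windows fit in one pass: one batch row per sample per window."""
    return batch_size // n_samples

print(windows_per_pass(4, 2))   # 2 (the notebook defaults)
print(windows_per_pass(4, 4))   # 1
print(windows_per_pass(4, 8))   # 0
```

With a result of 0 the sampling loop has nothing to schedule, so it seems safest to keep batch_sizes[2] at least as large as hps.n_samples.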

## ---Sample Settings

In [None]:
lyrics_condition = """
'Tari, get over here
You're too slow!
It's, it's, it's, it's lookin' like I'm right
I-I-I, I-I-I
I hope that you takin' my side
Oh

(You're too slow!)
Try and keep up with me now
(You're too slow!)

Try and keep up with me now
Ke-keep up, keep up
Ke-ke-ke-keep up, keep up
Ke-ke-keep up, keep up
Ke-ke-ke-keep up, keep up
Keep, keep—
(—Too slow!)

There just ain't no one that stop me
Anyone can get they bitch taken
Speed of light, I'm him!
Before you even think you can get the chance
Just remember who put you on first!
Toxic, toxic, toxic—
(You're too slow!)
Try and keep up with me now
(You're too slow!)
Try and keep up with me now
Ke-keep up, keep up
Ke-ke-ke-keep up, keep up
Ke-ke-keep up, keep up
Ke-ke-ke-keep up, keep up
Keep, keep—

I'm always wantin' what I just can't have
Butterflies in my stomach 'cause you're just so bad
I, I, showed you wrong
I'ma shit on you

(You're too slow!)
Try and keep up with me now
(You're too slow!)
Try and keep up with me now
Ke-keep up, keep up
Ke-ke-ke-keep up, keep up
Ke-ke-keep up, keep up
Ke-ke-ke-keep up, keep up
Keep, keep—
It's lookin' like I'm right

I-I-I, I-I-I
I hope that you takin' my side
Oh

(You're too slow!)
Try and keep up with me now
(You're too slow!)
Try and keep up with me now
Ke-ke-keep up, keep up
Ke-ke-ke-keep up, keep up
Ke-ke-keep up, keep up
Ke-ke-ke-keep up, keep up
Keep, keep—

Toxic, toxic, toxic—
(You're too slow!)
"""
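The sampling cell below offers three timing modes (tokens, seconds, bpm) that decide how many top-level tokens each section gets. Here is a standalone sketch of that arithmetic, lifted from the cell's own formulas, so you can try values before committing to a long run. The function name `tokens_for_next_section` is made up for this example; 44100 and 128 are hps.sr and the level 2 raw-to-tokens ratio from the notebook log.

```python
SAMPLE_RATE = 44100     # hps.sr
RAW_TO_TOKENS = 128     # top_prior.raw_to_tokens, per the notebook log

def tokens_for_next_section(tokens_done, timing_mode, amount=0, bpm=120.0, beats=4):
    """Tokens to generate next so the section ends on a whole second or beat."""
    tokens_per_second = SAMPLE_RATE / RAW_TO_TOKENS
    if timing_mode == "tokens":
        return amount
    if timing_mode == "seconds":
        seconds_done = round(tokens_done / tokens_per_second) + amount
        return round(seconds_done * tokens_per_second) - tokens_done
    if timing_mode == "bpm":
        seconds_per_beat = 60 / bpm
        beats_done = round(tokens_done / tokens_per_second / seconds_per_beat) + beats
        return round(beats_done * seconds_per_beat * tokens_per_second) - tokens_done
    raise ValueError(f"unknown timing mode: {timing_mode}")

print(tokens_for_next_section(0, "bpm", bpm=162, beats=4))     # 510
print(tokens_for_next_section(510, "bpm", bpm=162, beats=4))   # 511
print(tokens_for_next_section(0, "seconds", amount=2))         # 689
```

Note how the second bar gets 511 tokens instead of 510: section ends are rounded against the absolute beat position, so rounding errors do not pile up over a long track.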

In [None]:
#@title ##---Sampling and upsampling
sampling_temperature = 0.97#@param {type: "number"}

artist_condition = 'unknown'#@param {type: "string"}
genre_condition = 'dance'#@param {type: "string"}

#@markdown ---

timing_mode = 'bpm'#@param ['tokens', 'seconds', 'bpm']

#@markdown Token/Seconds Mode

amount = 116#@param {type: "integer"}

#@markdown BPM Mode

bpm = 162#@param {type: "number"}
bars = 1#@param {type: "integer"}
beats = 0#@param {type: "integer"}

beats += bars * 4

if True:
  #zs = [t.zeros(hps.n_samples,0,dtype=t.long, device='cpu') for _ in range(hps.levels)]
  try:
    top_prior
  except:
    top_prior = None

  if top_prior:
    top_prior.prior.transformer.del_cache()
    first = True
    while zs[2].shape[1] < duration // top_prior.raw_to_tokens:
      left_to_sample = duration // top_prior.raw_to_tokens - zs[2].shape[1]

      amnt_to_sample = []
      zs_size_projected = zs[2].shape[1]
      for i in range(batch_sizes[2] // hps.n_samples):

        if timing_mode == 'tokens':
          new_tokens = amount
        elif timing_mode == 'seconds':
          seconds_done = round(zs_size_projected / (hps.sr / top_prior.raw_to_tokens))
          seconds_done += amount
          new_tokens = round(seconds_done * (hps.sr / top_prior.raw_to_tokens)) - zs_size_projected
        elif timing_mode == 'bpm':
          beats_done = round((zs_size_projected / (hps.sr / top_prior.raw_to_tokens)) / (60 / bpm))
          beats_done += beats
          new_tokens = round(beats_done * (60 / bpm) * (hps.sr / top_prior.raw_to_tokens)) - zs_size_projected

        assert new_tokens < top_prior.n_ctx, 'section is too long: new_tokens must be below top_prior.n_ctx'

        amnt_to_sample.append(min(new_tokens, left_to_sample))
        left_to_sample -= new_tokens
        zs_size_projected += new_tokens
        if left_to_sample <= 0:
          break

      print('amnt_to_sample:', amnt_to_sample)
      max_to_sample = max(amnt_to_sample)
      #print('max_to_sample:', max_to_sample)
      print(zs[2].shape[1], '/', duration // top_prior.raw_to_tokens, '| aprox windows left:', round((duration // top_prior.raw_to_tokens - zs[2].shape[1]) / max_to_sample, 1))
      metas = []
      xs_pre_cat = []

      zs_size_projected = zs[2].shape[1]
      for i in range(len(amnt_to_sample)):
        amnt_to_encode = top_prior.n_ctx - max_to_sample
        start = zs_size_projected - amnt_to_encode
        metas.extend([dict(
            artist = artist_condition,
            genre = genre_condition,
            total_length = duration,
            offset = max(0, start) * top_prior.raw_to_tokens,
            lyrics = lyrics_condition
            )] * hps.n_samples
        )

        if zs_size_projected > 0:
          enc = raw_audio[:, max(0, start * top_prior.raw_to_tokens) : (start + top_prior.n_ctx) * top_prior.raw_to_tokens]
          xs = vqvae.encode(enc, bs_chunks=raw_audio.shape[0])
          #x = vqvae.decode(xs[2:], start_level=2).cpu().numpy()
          #for i in range(hps.n_samples):
          #  librosa.output.write_wav(f'{hps.name}top_level_encdec_{i}.wav', x[i][:], sr=hps.sr)

          #print('xs:', xs[2].shape)
          xss = xs[-1][:,:-max_to_sample]
          #print('xss:', xss.shape)
          #xss = xss[:, max(0,top_prior.n_ctx - (xss.shape[1] + max_to_sample)):]
          #print('xss:', xss.shape)
        else:
          xss = t.zeros(hps.n_samples,0,dtype=t.long, device='cpu')

        xs_pre_cat.append(xss)
        zs_size_projected += amnt_to_sample[i]

      labels = top_prior.labeller.get_batch_labels(metas, 'cuda')

      #print('metas:', metas)
      #print('labels:', labels)

      mx = max(map(lambda x: x.shape[1], xs_pre_cat))
      #print('mx:', mx)
      xs_pre_cat = list(map(lambda x: F.pad(x, (mx - x.shape[1], 0), mode='constant', value=0), xs_pre_cat))

      #print('xs_pre_cat:', list(map(lambda x: x.shape, xs_pre_cat)))

      xs = [
        t.zeros(hps.n_samples,0,dtype=t.long, device='cpu'),
        t.zeros(hps.n_samples,0,dtype=t.long, device='cpu'),
        t.cat(xs_pre_cat, dim=0)
      ]

      #print(xs[2].shape)
      #print(max_to_sample)
      sampling_kwargs = dict(temp=sampling_temperature, fp16=True, max_batch_size=batch_sizes[2], chunk_size=chunk_size)

      if use_new_jukebox_saveopt:
        xs=sample_partial_window(xs, labels, sampling_kwargs, 2, top_prior, max_to_sample, hps, autosave=False)
      else:
        xs=sample_partial_window(xs, labels, sampling_kwargs, 2, top_prior, max_to_sample, hps)

      foobar = []
      for i in range(len(amnt_to_sample)):
        foobar.append(xs[2][i*hps.n_samples:(i+1)*hps.n_samples,-max_to_sample:][:,:amnt_to_sample[i]])
      #print('foobar:', foobar)
      zs[2] = t.cat((zs[2], t.cat(foobar, dim=1)), dim=1)
      t.save(zs, f'{hps.name}tokens.t')
      first = False

  metas = [dict(artist = artist_condition,
                genre = genre_condition,
                total_length = duration,
                offset = 0,
                lyrics = lyrics_condition,
                ), ] * hps.n_samples

  if False: zs = t.load(f'{hps.name}tokens.t')

assert zs[2].shape[1]>=2048, f'Please first generate at least 2048 tokens at the top level, currently you have {zs[2].shape[1]}'
hps.sample_length = zs[2].shape[1]*128

# Set this False if you are on a local machine that has enough memory (this allows you to do the
# lyrics alignment visualization). For a hosted runtime, we'll need to go ahead and delete the top_prior
# if you are using the 5b_lyrics model.
if True:
  del top_prior
  empty_cache()
  top_prior=None

upsamplers = [make_prior(setup_hparams(prior, dict()), vqvae, 'cpu') for prior in priors[:-1]]
empty_cache()

sampling_kwargs = [dict(temp=0.995, fp16=True, max_batch_size=batch_sizes[0], chunk_size=128),
                    dict(temp=0.995, fp16=True, max_batch_size=batch_sizes[1], chunk_size=128),
                    None]

labels = [prior.labeller.get_batch_labels(metas, 'cuda') for prior in upsamplers] + [upsamplers[0].labeller.get_batch_labels(metas, 'cuda')]
empty_cache()

empty_cache()
for prior in upsamplers:
  prior.prior.transformer.del_cache()
empty_cache()
zs = upsample(zs, labels, sampling_kwargs, [*upsamplers, top_prior], hps)

#@markdown This cell will run for about 3-5 hours.

#@markdown Running in 5b_lyrics mode on a T4 GPU is not recommended: it will take many hours and use a lot of your GPU's RAM. In that case, opt for 1b_lyrics instead.

amnt_to_sample: [689, 689]
0 / 15033 | aprox windows left: 21.8
Input genre alternative maps to the list ['alternative']. alternative is not present in /usr/local/lib/python3.10/dist-packages/jukebox/data/ids/v3_genre_ids.txt. Defaulting to (word_id, word) = (0, unknown), if that seems wrong please format genre correctly
Input genre alternative maps to the list ['alternative']. alternative is not present in /usr/local/lib/python3.10/dist-packages/jukebox/data/ids/v3_genre_ids.txt. Defaulting to (word_id, word) = (0, unknown), if that seems wrong please format genre correctly
Input genre alternative maps to the list ['alternative']. alternative is not present in /usr/local/lib/python3.10/dist-packages/jukebox/data/ids/v3_genre_ids.txt. Defaulting to (word_id, word) = (0, unknown), if that seems wrong please format genre correctly
Input genre alternative maps to the list ['alternative']. alternative is not present in /usr/local/lib/python3.10/dist-packages/jukebox/data/ids/v3_genre_ids.t