<a href="https://colab.research.google.com/github/Mllck/One-Click-Jukebox-Continuous-Repriming-by-Michaels-Lab/blob/main/One-Click%20Jukebox%20Continuous%20Repriming%20Method.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# (Almost) One-Click Jukebox Continuous Reprime Notebook

* Join the Jukebox community at https://discord.gg/aEqXFN9amV

* Video guide by Broccaloo for the Jukebox AI: https://vimeo.com/817968335

* Explanation video by Michaels Lab about his new notebook: https://youtu.be/BPo5sECkBV4?si=4niezv12iyElPn7J

#### **How to handle memory problems**
In theory, this notebook is crafted to avoid out-of-memory errors, but here are some tricks if you still encounter one:
* Restart runtime: At the top of the notebook, click "Runtime" and then "Restart runtime". Then run everything again. Do this every time you start a second run within the same session, or after you've interrupted one.
* Decrease sample count: Choose a lower number for 'hps.n_samples'.

# Guide to the settings below

**your_lyrics:** Specify the lyrics Jukebox should attempt to follow by pasting them between the triple quotes in the `your_lyrics` cell. You can paste any lyrics you want, or leave it blank, which will result in gibberish vocals.

**model:**
OpenAI has trained a few different models for Jukebox. In this notebook, you can access the 5b_lyrics, 5b and 1b_lyrics models. As you can imagine, the 5b_lyrics model is the superior one, but also requires a stronger GPU to run properly. Which model you should choose depends on the GPU you were assigned, which you can check in the first cell of the notebook. Recommended settings: 5b_lyrics on P100 or T4 GPU, 1b_lyrics on K80 GPU.
(5b_lyrics theoretically works on a K80 now, but sampling is going to be very slow.)
(5b is like 5b_lyrics but does not support custom lyrics, so it generates gibberish vocals.)
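The GPU recommendations above can be written as a small lookup on the line printed by the GPU-check cell (`nvidia-smi -L`). This is a minimal sketch: `recommend_model` is a hypothetical helper, not part of the notebook or of Jukebox, it only encodes the P100/T4/K80 advice from this guide, and falling back to `1b_lyrics` for GPUs the guide does not mention is an assumption of the sketch.

```python
import re

# Hypothetical helper encoding this guide's recommendations (not part of the notebook or Jukebox)
RECOMMENDED = {"P100": "5b_lyrics", "T4": "5b_lyrics", "K80": "1b_lyrics"}

def recommend_model(nvidia_smi_line: str) -> str:
    """Suggest a model for a line such as 'GPU 0: Tesla T4 (UUID: ...)'."""
    match = re.search(r"GPU \d+: (.+?) \(UUID", nvidia_smi_line)
    gpu_name = match.group(1) if match else ""
    for key, model_name in RECOMMENDED.items():
        # Whole-token match, so 'Tesla P100-PCIE-16GB' still counts as a P100
        if re.search(rf"\b{re.escape(key)}\b", gpu_name):
            return model_name
    # Assumption: use the lighter model for GPUs this guide does not mention
    return "1b_lyrics"

print(recommend_model("GPU 0: Tesla T4 (UUID: GPU-31c84ff0-28c7-8a78-8fa3-59bf3499eee6)"))  # 5b_lyrics
```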

Lists of the v2 (5b_lyrics and 5b models) and v3 (1b_lyrics model) artist and genre IDs:
https://github.com/openai/jukebox/tree/master/jukebox/data/ids
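To check whether an `artist_condition` or `genre_condition` value appears in those lists before starting a long run, you can look the name up in the relevant `.txt` file. The sketch below assumes each line has the form `name;id` and compares names case-insensitively; both the file format and the matching rule are assumptions about the linked files, and Jukebox may normalize names differently, so treat a miss as a reason to double-check rather than a guarantee. `parse_ids` and `has_name` are hypothetical helpers.

```python
def parse_ids(text: str) -> dict:
    """Parse lines of the assumed form 'name;id' into {lowercased name: id}."""
    ids = {}
    for line in text.splitlines():
        name, sep, raw_id = line.strip().rpartition(";")
        if not sep:  # skip blank or malformed lines
            continue
        ids[name.lower()] = int(raw_id)
    return ids

def has_name(ids: dict, name: str) -> bool:
    """Case-insensitive lookup (Jukebox's own normalization may differ)."""
    return name.strip().lower() in ids

# Illustrative input only; the real lists are in the repository folder linked above
sample = "unknown;0\nExample_Artist;42\n"
ids = parse_ids(sample)
print(has_name(ids, "example_artist"))  # True
```

In the Colab runtime, the installed package already contains these files: the main cell's log shows them being loaded from the installed `jukebox/data/ids` folder.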

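**hps.name:** Specifies the Google Drive folder where your results are saved. Enter the full path, ending with a slash (the default, `/content/gdrive/My Drive/`, is the root of your Drive). Choose a different folder for each run: if the folder already contains saved tokens from an earlier run, the notebook resumes from them instead of starting fresh.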
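Because the main cell builds file paths by plain string concatenation (for example `hps.name + 'primer.wav'` and `f'{hps.name}tokens.t'`), `hps.name` must end with a slash. A minimal sketch of normalizing the folder path; `run_folder` is a hypothetical helper and `jukebox_run_01` is only an example folder name.

```python
import os

def run_folder(name: str, drive_root: str = "/content/gdrive/My Drive") -> str:
    """Return an absolute Drive folder path ending in '/', safe for string concatenation."""
    if not name.startswith("/"):
        # Treat relative names as folders inside the Drive root
        name = os.path.join(drive_root, name)
    return name.rstrip("/") + "/"

folder = run_folder("jukebox_run_01")
print(folder + "primer.wav")  # /content/gdrive/My Drive/jukebox_run_01/primer.wav
```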

**speed_upsampling:** If selected, upsampling runs much faster, at the cost of the samples sounding slightly "choppy".

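**audio_file:** Specifies the song Jukebox uses as the primer for the new audio. Upload the file you want (WAV, FLAC, MP3, M4A, OPUS, etc.) to your Google Drive. Note that the main cell as written has no separate `audio_file` field: it loads the primer from a file named `primer.wav` inside the `hps.name` folder, so save your song there under that name (converting it to WAV first if needed).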
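The main cell trims the primer to a whole number of top-level tokens: it multiplies the primer length in seconds by `hps.sr` (44100), rounds down to a multiple of 128 samples (the top level's raw-to-tokens factor, shown as `Raw to tokens:128` in the log), samples the top level until it matches that length, and refuses to upsample with fewer than 2048 top-level tokens. The arithmetic as a standalone sketch; `primer_tokens` is a hypothetical function and the 10-second primer is just an example.

```python
SR = 44100            # hps.sr in the notebook
RAW_TO_TOKENS = 128   # audio samples per top-level token
MIN_TOKENS = 2048     # minimum checked before upsampling starts

def primer_tokens(seconds: float) -> tuple:
    """Return (sample_length, top_level_tokens) the way the main cell computes them."""
    sample_length = (int(seconds * SR) // RAW_TO_TOKENS) * RAW_TO_TOKENS
    return sample_length, sample_length // RAW_TO_TOKENS

length, tokens = primer_tokens(10.0)
print(length, tokens, tokens >= MIN_TOKENS)       # 440960 3445 True
print(round(MIN_TOKENS * RAW_TO_TOKENS / SR, 2))  # 5.94 (shortest usable primer, in seconds)
```

So, as far as the code shows, the primer needs to be roughly six seconds or longer.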

**sampling_temperature:** Determines the creativity and energy of Jukebox: the higher the temperature, the more chaotic and intense the result. You can experiment with this; we recommend keeping it between 0.95 and 0.995. It applies to the top-level sampling only; the upsampling levels use a fixed temperature of 0.975.
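For intuition about what the temperature does: autoregressive samplers commonly divide the model's logits by the temperature before the softmax, so values below 1 sharpen the distribution over the next token and values above 1 flatten it. The sketch below shows that generic technique on made-up logits; it is not Jukebox's sampling code, and the logit values are purely illustrative.

```python
import math

def softmax_with_temperature(logits, temp):
    """Softmax of logits / temp (generic temperature scaling)."""
    scaled = [x / temp for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.1]  # illustrative values only
for temp in (0.95, 0.995, 1.5):
    probs = softmax_with_temperature(logits, temp)
    print(temp, [round(p, 3) for p in probs])
```

As the temperature rises, the most likely token loses probability to the alternatives, which is one way to picture why higher settings sound less predictable.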


In [None]:
#@title ##---Check which GPU you were assigned by running this cell. { vertical-output: true, form-width: "32%" }
!nvidia-smi -L

GPU 0: Tesla T4 (UUID: GPU-31c84ff0-28c7-8a78-8fa3-59bf3499eee6)


In [None]:
your_lyrics = """"""

In [None]:
#@title ##---Main Code { vertical-output: true, form-width: "50%" }
use_new_jukebox_saveopt = True
from google.colab import drive
drive.mount('/content/gdrive')
# The jukebox-saveopt fork is installed whether or not use_new_jukebox_saveopt is set;
# the flag only changes how the code below calls into it.
!pip install --upgrade git+https://github.com/craftmine1000/jukebox-saveopt.git

import jukebox
import torch as t
import torch.nn.functional as F
import librosa
import os
import numpy as np
from IPython.display import Audio
from jukebox.make_models import make_vqvae, make_prior, MODELS, make_model
from jukebox.hparams import Hyperparams, setup_hparams
from jukebox.sample import sample_single_window, _sample, \
                           sample_partial_window, upsample, \
                           load_prompts
from jukebox.utils.dist_utils import setup_dist_from_mpi
from jukebox.utils.torch_utils import empty_cache
rank, local_rank, device = setup_dist_from_mpi()

model = "5b_lyrics" #@param ['5b_lyrics', '5b', '1b_lyrics']
hps = Hyperparams()
hps.sr = 44100
hps.n_samples = 1
# Must end with '/': file paths below are built by appending names such as 'primer.wav'
hps.name = '/content/gdrive/My Drive/' #@param {type: "string"}
chunk_size = 128  # same value for every model
hps.hop_fraction = [1, 4, .125]
batch_sizes = [4, 6, 2] if model=="5b_lyrics" else [2, 4, 2]
hps.levels = 3

# Note: max_batch_size is not referenced below; the batch sizes actually used come from batch_sizes.
max_batch_size = 8

if not use_new_jukebox_saveopt:
  for i in range(2):
    if hps.hop_fraction[i] > 1:
      hps.hop_fraction[i] = 1

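# librosa >= 0.10 renamed get_duration's `filename` argument to `path` (the old name is deprecated)
primer_length_in_seconds = librosa.get_duration(path=hps.name + 'primer.wav')
# Round down to a whole number of top-level tokens (128 samples each)
duration = (int(primer_length_in_seconds * hps.sr) // 128) * 128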
hps.sample_length = duration

raw_audio = load_prompts([hps.name + 'primer.wav'], duration, hps)

#print(raw_audio)
#print(raw_audio.shape)
#print(t.max(t.abs(raw_audio)))
# Reference snippet (not applied): a guard meant for vqvae.decode in jukebox/vqvae/vqvae.py,
# so that decoding an empty token sequence returns an empty tensor instead of failing.
# It is defined as a plain function and never attached to the VQ-VAE, so it has no effect here.
def decode(self, zs, start_level=0, end_level=None, bs_chunks=1, length_chunks=1):
    # If the last level's tensor has zero length, return an empty result right away
    if zs[-1].shape[1] == 0:
        return t.zeros(zs[0].shape[0], 0, device=zs[0].device, dtype=zs[0].dtype)

    # Modified logic to prevent length_chunks from being 0
    length = zs[-1].shape[1]
    length_chunks = max(1, int(np.ceil(length / self.hps.sample_length)))
    # (The actual decoding step is omitted from this fragment.)

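# Resume a previous run if possible: first the level_1 checkpoint, then the saved top-level
# tokens; otherwise start with empty token sequences for every level.
try:
  zs = t.load(f'{hps.name}level_1/data.pth.tar')['zs']
except Exception:
  try:
    zs = t.load(f'{hps.name}tokens.t')
  except Exception:
    zs = [t.zeros(hps.n_samples,0,dtype=t.long, device='cpu') for _ in range(hps.levels)]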

vqvae, *priors = MODELS[model]
vqvae = make_vqvae(setup_hparams(vqvae, dict(sample_length = 1048576)), device)
#vqvae.c_to(device)
if zs[-1].shape[1] < duration // 128:
  top_prior = make_prior(setup_hparams(priors[-1], dict()), vqvae, device)
  if use_new_jukebox_saveopt:
    top_prior.c_to(device)


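# Note: speed_upsampling is not referenced anywhere below in this cell, so it currently has no effect.
speed_upsampling = True  #@param {type: "boolean"}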

#@markdown ---
# select_artist and select_genre are not used below; the sampler receives artist_condition and genre_condition instead.
select_artist = "unknown"
select_genre = "unknown"
#zs=[t.zeros(hps.n_samples,0,dtype=t.long, device='cpu') for _ in range(hps.levels)]

sampling_temperature = 0.97#@param {type: "number"}
artist_condition = ''#@param {type: "string"}
genre_condition = ''#@param {type: "string"}

#@markdown ---

timing_mode = 'tokens'#@param ['tokens', 'seconds', 'bpm']

#@markdown ####Tokens/Seconds Mode

amount = 0#@param {type: "integer"}

#@markdown ####BPM Mode

bpm = 0#@param {type: "number"}
bars = 0#@param {type: "integer"}
beats = 0#@param {type: "integer"}

# Assumes 4 beats per bar (4/4 time)
beats += bars * 4

if True:
  #zs = [t.zeros(hps.n_samples,0,dtype=t.long, device='cpu') for _ in range(hps.levels)]
  # top_prior only exists if the top level still needs sampling (see above)
  try:
    top_prior
  except NameError:
    top_prior = None

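  if top_prior:
    # The default amount / bpm / bars / beats of 0 would otherwise fail later with a ZeroDivisionError
    if timing_mode == 'bpm':
      assert bpm > 0 and beats > 0, 'Set bpm and bars/beats to positive values'
    else:
      assert amount > 0, 'Set amount to a positive number of tokens or seconds'
    top_prior.prior.transformer.del_cache()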
    first = True
    while zs[2].shape[1] < duration // top_prior.raw_to_tokens:
      left_to_sample = duration // top_prior.raw_to_tokens - zs[2].shape[1]

      amnt_to_sample = []
      zs_size_projected = zs[2].shape[1]
      for i in range(batch_sizes[2] // hps.n_samples):

        if timing_mode == 'tokens':
          new_tokens = amount
        elif timing_mode == 'seconds':
          seconds_done = round(zs_size_projected / (hps.sr / top_prior.raw_to_tokens))
          seconds_done += amount
          new_tokens = round(seconds_done * (hps.sr / top_prior.raw_to_tokens)) - zs_size_projected
        elif timing_mode == 'bpm':
          beats_done = round((zs_size_projected / (hps.sr / top_prior.raw_to_tokens)) / (60 / bpm))
          beats_done += beats
          new_tokens = round(beats_done * (60 / bpm) * (hps.sr / top_prior.raw_to_tokens)) - zs_size_projected

        assert new_tokens < top_prior.n_ctx, f'Section too long: each step must be shorter than the top prior context of {top_prior.n_ctx} tokens'

        amnt_to_sample.append(min(new_tokens, left_to_sample))
        left_to_sample -= new_tokens
        zs_size_projected += new_tokens
        if left_to_sample <= 0:
          break

      print('amnt_to_sample:', amnt_to_sample)
      max_to_sample = max(amnt_to_sample)
      #print('max_to_sample:', max_to_sample)
      print(zs[2].shape[1], '/', duration // top_prior.raw_to_tokens, '| approx. windows left:', round((duration // top_prior.raw_to_tokens - zs[2].shape[1]) / max_to_sample, 1))
      metas = []
      xs_pre_cat = []

      zs_size_projected = zs[2].shape[1]
      for i in range(len(amnt_to_sample)):
        amnt_to_encode = top_prior.n_ctx - max_to_sample
        start = zs_size_projected - amnt_to_encode
        metas.extend([dict(
            artist = artist_condition,
            genre = genre_condition,
            total_length = duration,
            offset = max(0, start) * top_prior.raw_to_tokens,
            lyrics = your_lyrics
            )] * hps.n_samples
        )

        if zs_size_projected > 0:
          enc = raw_audio[:, max(0, start * top_prior.raw_to_tokens) : (start + top_prior.n_ctx) * top_prior.raw_to_tokens]
          xs = vqvae.encode(enc, bs_chunks=raw_audio.shape[0])
          #x = vqvae.decode(xs[2:], start_level=2).cpu().numpy()
          #for i in range(hps.n_samples):
          #  librosa.output.write_wav(f'{hps.name}top_level_encdec_{i}.wav', x[i][:], sr=hps.sr)

          #print('xs:', xs[2].shape)
          xss = xs[-1][:,:-max_to_sample]
          #print('xss:', xss.shape)
          #xss = xss[:, max(0,top_prior.n_ctx - (xss.shape[1] + max_to_sample)):]
          #print('xss:', xss.shape)
        else:
          xss = t.zeros(hps.n_samples,0,dtype=t.long, device='cpu')

        xs_pre_cat.append(xss)
        zs_size_projected += amnt_to_sample[i]

      labels = top_prior.labeller.get_batch_labels(metas, 'cuda')

      #print('metas:', metas)
      #print('labels:', labels)

      # Left-pad every context to the same length so the windows can be stacked into one batch
      mx = max(x.shape[1] for x in xs_pre_cat)
      xs_pre_cat = [F.pad(x, (mx - x.shape[1], 0), mode='constant', value=0) for x in xs_pre_cat]

      #print('xs_pre_cat:', list(map(lambda x: x.shape, xs_pre_cat)))

      xs = [
        t.zeros(hps.n_samples,0,dtype=t.long, device='cpu'),
        t.zeros(hps.n_samples,0,dtype=t.long, device='cpu'),
        t.cat(xs_pre_cat, dim=0)
      ]

      #print(xs[2].shape)
      #print(max_to_sample)
      sampling_kwargs = dict(temp=sampling_temperature, fp16=True, max_batch_size=batch_sizes[2], chunk_size=chunk_size)

      if use_new_jukebox_saveopt:
        xs=sample_partial_window(xs, labels, sampling_kwargs, 2, top_prior, max_to_sample, hps, autosave=False)
      else:
        xs=sample_partial_window(xs, labels, sampling_kwargs, 2, top_prior, max_to_sample, hps)

      # Keep only the newly sampled tokens of each window, then append them to the top level in order
      new_sections = []
      for i in range(len(amnt_to_sample)):
        new_sections.append(xs[2][i*hps.n_samples:(i+1)*hps.n_samples, -max_to_sample:][:, :amnt_to_sample[i]])
      zs[2] = t.cat((zs[2], t.cat(new_sections, dim=1)), dim=1)
      t.save(zs, f'{hps.name}tokens.t')
      first = False

  metas = [dict(artist = artist_condition,
                genre = genre_condition,
                total_length = duration,
                offset = 0,
                lyrics = your_lyrics,
                ), ] * hps.n_samples

  if False: zs = t.load(f'{hps.name}tokens.t')

assert zs[2].shape[1]>=2048, f'Please first generate at least 2048 tokens at the top level, currently you have {zs[2].shape[1]}'
hps.sample_length = zs[2].shape[1]*128

# Change this to True to delete the top prior and free its memory before upsampling
# (e.g. on a hosted runtime with the 5b_lyrics model). Keep it False on a local machine with
# enough memory; that also lets you do the lyrics alignment visualization.
if False:
  del top_prior
  empty_cache()
  top_prior=None

upsamplers = [make_prior(setup_hparams(prior, dict()), vqvae, 'cpu') for prior in priors[:-1]]
empty_cache()

sampling_kwargs = [dict(temp=0.975, fp16=True, max_batch_size=batch_sizes[0], chunk_size=128),
                    dict(temp=0.975, fp16=True, max_batch_size=batch_sizes[1], chunk_size=128),
                    None]

labels = [prior.labeller.get_batch_labels(metas, 'cuda') for prior in upsamplers] + [upsamplers[0].labeller.get_batch_labels(metas, 'cuda')]
empty_cache()
for prior in upsamplers:
  prior.prior.transformer.del_cache()
empty_cache()
zs = upsample(zs, labels, sampling_kwargs, [*upsamplers, top_prior], hps)
disconnect_runtime_after_finish = True #@param {type: "boolean"}
if disconnect_runtime_after_finish:
  from google.colab import runtime
  runtime.unassign()
#@markdown ---
#@markdown ####This cell runs for about 3-5 hours.

#@markdown ####The 5b_lyrics model is not recommended on a T4 GPU, because it takes several hours and uses a lot of GPU memory. In that case, opt for 1b_lyrics.

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).
Collecting git+https://github.com/craftmine1000/jukebox-saveopt.git
  Cloning https://github.com/craftmine1000/jukebox-saveopt.git to /tmp/pip-req-build-e0h0fn14
  Running command git clone --filter=blob:none --quiet https://github.com/craftmine1000/jukebox-saveopt.git /tmp/pip-req-build-e0h0fn14
  Resolved https://github.com/craftmine1000/jukebox-saveopt.git to commit 5b76e9e07eb15dba6fd79da99d76d1ecb32a7ea5
  Preparing metadata (setup.py) ... done
Using cuda True




Downloading from azure




Restored from /root/.cache/jukebox/models/5b/vqvae.pth.tar
0: Loading vqvae in eval mode
Loading artist IDs from /usr/local/lib/python3.11/dist-packages/jukebox/data/ids/v2_artist_ids.txt
Loading artist IDs from /usr/local/lib/python3.11/dist-packages/jukebox/data/ids/v2_genre_ids.txt
Level:2, Cond downsample:None, Raw to tokens:128, Sample length:1048576
Downloading from azure


### Important links

* Official blog: https://openai.com/blog/jukebox/
* Original repo: https://github.com/openai/jukebox/

* License: Non-commercial, for details see: https://github.com/openai/jukebox/blob/master/LICENSE

* The original notebook was created by: Jaime v2.0 - Since 2018. (https://www.youtube.com/channel/UCWbk5lrnDGB6SnhnIwcDZ4w)