# **Mubert Text to Music ✍ ➡ 🎹🎵🔊**

* A simple notebook demonstrating prompt-based music generation via [Mubert](https://mubert.com) [API](https://mubert2.docs.apiary.io/)

* Original GitHub repository: https://github.com/MubertAI/Mubert-Text-to-Music

* **Updated by [@3chain](http://twitter.com/web3chain) to add streaming of any Mubert Category, Group, or Channel.**

**TODO:**
* **Use prompt to pick stream.**

# 1. SETUP - AUTOMATIC

In [None]:
#@title ## 1.1 Setup Environment

import subprocess, time
print("Setting up environment...")
start_time = time.time()
all_process = [
    ['pip', 'install', 'torch==1.12.1+cu113', 'torchvision==0.13.1+cu113', '--extra-index-url', 'https://download.pytorch.org/whl/cu113'],
    ['pip', 'install', '-U', 'sentence-transformers'],
    ['pip', 'install', 'httpx'],
]
for process in all_process:
    # pip's stdout is captured to keep the cell output short. It is shown only
    # when the command fails, so a broken install doesn't pass silently.
    result = subprocess.run(process, stdout=subprocess.PIPE)
    if result.returncode != 0:
        print(f"WARNING: `{' '.join(process)}` exited with code {result.returncode}")
        print(result.stdout.decode('utf-8', errors='replace'))

end_time = time.time()
print(f"Environment set up in {end_time-start_time:.0f} seconds")

In [None]:
#@title ## 1.2 Setup custom functions

def print_html(text,font_family='Montserrat',paragraph_type='h1', font_color='white'):
  """Render ``text`` in the notebook inside a ``<paragraph_type>`` element styled with a Google Font.

  Args:
    text: HTML content to render. It is inserted as-is, so markup such as links
      or lists is allowed.
    font_family: Google Fonts family name, e.g. 'Montserrat' or 'Open Sans'.
    paragraph_type: tag name used both for the element and for its CSS rule
      (e.g. 'h1', 'h3').
    font_color: CSS color value for the text.
  """
  from IPython.display import HTML, display
  # Google Fonts expects spaces in family names to be written as '+'.
  font_url = f"https://fonts.googleapis.com/css?family={font_family.replace(' ', '+')}"
  # The space in '@import url(...)' is required; without it the rule is invalid
  # CSS and the font stylesheet is never loaded.
  style = f"@import url({font_url});{paragraph_type}{{font-family:'{font_family}',serif;color:{font_color};}}"
  html_text = f"<html><head><style>{style}</style></head><body><{paragraph_type}>{text}</{paragraph_type}></body></html>"
  display(HTML(data=html_text))

# TODO: this only prints Graphviz DOT text; rendering it as an actual diagram is
# not implemented yet.
def display_json_as_diagram(json_data):
  """Print a nested ``{name: {"children": [...]}}`` tree as a Graphviz DOT digraph.

  For each node dict, only its first key is used as the node name. The value is a
  dict with an optional ``"children"`` list. Each list item is either a nested node
  dict or a plain leaf value. Node names are quoted, so names containing spaces or
  punctuation still produce valid DOT.

  NOTE(review): the Mubert GetPlayMusic response (data.categories -> groups ->
  channels) does not have this shape, so it cannot be passed in directly.

  Args:
    json_data: the root node dict.
  """
  edges = []

  def get_edges(treedict, parent=None):
    name = next(iter(treedict))
    if parent is not None:
      edges.append((parent, name))
    for item in treedict[name].get("children", []):
      if isinstance(item, dict):
        get_edges(item, parent=name)
      else:
        edges.append((name, item))

  def dot_id(value):
    # Inside a DOT quoted ID, only embedded double quotes need escaping.
    return '"' + str(value).replace('"', '\\"') + '"'

  get_edges(json_data)
  # Dump the edge list in Graphviz DOT format.
  print('strict digraph tree {')
  for parent, child in edges:
    print(f'    {dot_id(parent)} -> {dot_id(child)};')
  print('}')


In [None]:
#@title ## 1.3 Setup Mubert API

import numpy as np
from sentence_transformers import SentenceTransformer
from IPython.display import Audio, display
import httpx
import json

# Sentence-embedding model used to match free-text prompts against Mubert tags.
# MiniLM is a family of small pre-trained language models from Microsoft
# (https://github.com/microsoft/unilm/tree/master/minilm). This notebook loads the
# sentence-transformers checkpoint https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
minilm = SentenceTransformer(model_name_or_path='all-MiniLM-L6-v2')

# Endpoints for API calls are specified in the Mubert documentation: https://mubert2.docs.apiary.io/
# GLOBAL VARIABLES (constants)
RECORD_TRACK_ENDPOINT = 'https://api-b2b.mubert.com/v2/RecordTrackTTM'  # make sure there's no '/' at the end..
RECORD_TRACK_METHOD = 'RecordTrackTTM'
GET_CHANNELS_ENDPOINT = 'https://api-b2b.mubert.com/v2/GetPlayMusic'
GET_CHANNELS_METHOD = 'GetPlayMusic'

# Comma-separated list of every tag a prompt can be mapped to. It can also be
# retrieved via the Mubert endpoint https://api-b2b.mubert.com/v2/GetPlayMusic
# (The literal must stay on one line: a single-quoted string cannot span lines.)
MUBERT_CUSTOM_TAGS_STRING = 'tribal,action,kids,neo-classic,run 130,pumped,jazz / funk,ethnic,dubtechno,reggae,acid jazz,liquidfunk,funk,witch house,tech house,underground,artists,mystical,disco,sensorium,r&b,agender,psychedelic trance / psytrance,peaceful,run 140,piano,run 160,setting,meditation,christmas,ambient,horror,cinematic,electro house,idm,bass,minimal,underscore,drums,glitchy,beautiful,technology,tribal house,country pop,jazz & funk,documentary,space,classical,valentines,chillstep,experimental,trap,new jack swing,drama,post-rock,tense,corporate,neutral,happy,analog,funky,spiritual,sberzvuk special,chill hop,dramatic,catchy,holidays,fitness 90,optimistic,orchestra,acid techno,energizing,romantic,minimal house,breaks,hyper pop,warm up,dreamy,dark,urban,microfunk,dub,nu disco,vogue,keys,hardcore,aggressive,indie,electro funk,beauty,relaxing,trance,pop,hiphop,soft,acoustic,chillrave / ethno-house,deep techno,angry,dance,fun,dubstep,tropical,latin pop,heroic,world music,inspirational,uplifting,atmosphere,art,epic,advertising,chillout,scary,spooky,slow ballad,saxophone,summer,erotic,jazzy,energy 100,kara mar,xmas,atmospheric,indie pop,hip-hop,yoga,reggaeton,lounge,travel,running,folk,chillrave & ethno-house,detective,darkambient,chill,fantasy,minimal techno,special,night,tropical house,downtempo,lullaby,meditative,upbeat,glitch hop,fitness,neurofunk,sexual,indie rock,future pop,jazz,cyberpunk,melancholic,happy hardcore,family / kids,synths,electric guitar,comedy,psychedelic trance & psytrance,edm,psychedelic rock,calm,zen,bells,podcast,melodic house,ethnic percussion,nature,heavy,bassline,indie dance,techno,drumnbass,synth pop,vaporwave,sad,8-bit,chillgressive,deep,orchestral,futuristic,hardtechno,nostalgic,big room,sci-fi,tutorial,joyful,pads,minimal 170,drill,ethnic 108,amusing,sleepy ambient,psychill,italo disco,lofi,house,acoustic guitar,bassline house,rock,k-pop,synthwave,deep house,electronica,gabber,nightlife,sport & fitness,road trip,celebration,electro,disco house,electronic'

def play_audio_from_url(url, maximum_iterations=20, autoplay=True):
  """Poll ``url`` until it answers HTTP 200, then embed an audio player for it.

  The download link may not be ready right away (the caller prints
  'Generating track' before calling this). The URL is therefore retried once per
  second, with a '.' printed after each unsuccessful attempt.

  Args:
    url: the track's download URL.
    maximum_iterations: maximum number of attempts before giving up.
    autoplay: whether the embedded player starts playing immediately.

  Returns:
    True if the player was displayed, False if the URL never returned 200.
  """
  import time
  for _ in range(maximum_iterations):
    # The downloaded body is not used; this request only checks that the mp3 is
    # available. The Audio widget loads the file itself from the URL.
    r = httpx.get(url)
    if r.status_code == 200:
      display(Audio(url, autoplay=autoplay))
      return True
    time.sleep(1)
    print('.', end='')
  # Tell the user instead of ending silently with no player.
  print(f'\nTrack was not ready after {maximum_iterations} attempts: {url}')
  return False

# NOTE: this uses the RecordTrackTTM method, which is not listed in the API docs
# linked at the top of this notebook (only "RecordTrack" is). RecordTrackTTM
# works fine nonetheless.
def get_track_by_tags(tags, pat, duration, maxit=20, autoplay=False, loop=False):
  """Ask Mubert to generate a track for ``tags``, then play it once it is ready.

  In the original version of this notebook this was the only Mubert API call.
  The later HTTP GET only fetches the mp3 from the download link returned here.

  Args:
    tags: list of Mubert tag strings (see MUBERT_CUSTOM_TAGS_STRING).
    pat: Mubert PAT access token.
    duration: requested track duration (presumably seconds; the notebook passes
      30 or 60).
    maxit: how many times to poll the download link (once per second) before
      giving up.
    autoplay: start playback automatically.
    loop: request mode "loop" instead of "track".

  Raises:
    RuntimeError: if the API does not report success (status != 1). The message
      is the API's error text when available.
  """
  mode = "loop" if loop else "track"
  r = httpx.post(RECORD_TRACK_ENDPOINT,
      json={
          "method": RECORD_TRACK_METHOD,
          "params": {
              "pat": pat,
              "duration": duration,
              "tags": tags,
              "mode": mode
          }
      })

  rdata = r.json()
  # Explicit check instead of `assert`: asserts are stripped under `python -O`,
  # and a missing 'error' key must not hide the real failure behind a KeyError.
  if rdata.get('status') != 1:
    error_text = (rdata.get('error') or {}).get('text', f'Mubert API request failed: {rdata}')
    raise RuntimeError(error_text)
  trackurl = rdata['data']['tasks'][0]['download_link']  # URL of the generated mp3
  print('Generating track ', end='')
  print(trackurl)
  play_audio_from_url(trackurl, autoplay=autoplay, maximum_iterations=maxit)
  

def find_similar(em, embeddings, method='cosine'):
  """Score every reference embedding against ``em``. Lower scores mean closer matches.

  Used by ``get_tags_for_prompts`` to rank the Mubert tags against a prompt. The
  number of tags kept is chosen there via ``top_n``, not here.

  Args:
    em: the query embedding (1-D vector).
    embeddings: sequence or 2-D array of reference embeddings, one per row.
    method: 'cosine' for cosine distance (1 - cosine similarity), or 'norm' for
      Euclidean distance.

  Returns:
    ``(scores, order)``: the distance of each reference to ``em``, and the indices
    that sort those distances ascending (best match first).

  Raises:
    ValueError: if ``method`` is neither 'cosine' nor 'norm'.
  """
  if method not in ('cosine', 'norm'):
    raise ValueError(f"unknown method {method!r}; expected 'cosine' or 'norm'")
  refs = np.asarray(embeddings)
  if refs.size == 0:
    return np.array([]), np.argsort([])
  em = np.asarray(em)
  # Vectorised over all rows at once instead of a Python loop per reference.
  if method == 'cosine':
    scores = 1 - refs @ em / (np.linalg.norm(refs, axis=1) * np.linalg.norm(em))
  else:
    scores = np.linalg.norm(refs - em, axis=1)
  return scores, np.argsort(scores)

# (tags array, tag embeddings) keyed by the raw tags string. This way the full
# Mubert tag list is encoded once, instead of on every prompt lookup.
_mubert_tag_embeddings_cache = {}

def _encode_mubert_tags(tags_string):
  """Return ``(tags, embeddings)`` for a comma-separated tags string, encoding it only once."""
  cached = _mubert_tag_embeddings_cache.get(tags_string)
  if cached is None:
    tags = np.array(tags_string.split(','))
    cached = (tags, minilm.encode(tags))
    _mubert_tag_embeddings_cache[tags_string] = cached
  return cached

def get_tags_for_prompts(prompts, top_n=3, debug=False):
  """Find the ``top_n`` Mubert tags closest to each prompt.

  Prompts and tags are embedded with the MiniLM model and ranked by cosine
  distance (see ``find_similar``).

  Args:
    prompts: list of prompt strings. A single string is also accepted.
    top_n: number of tags returned per prompt.
    debug: if True, print each prompt with its tags and their cosine similarity.

  Returns:
    List of ``(prompt, [tag, ...])`` tuples, in the same order as ``prompts``.
  """
  if isinstance(prompts, str):
    # encode() on a bare string returns one vector, not one row per prompt.
    prompts = [prompts]
  mubert_tags, mubert_tags_embeddings = _encode_mubert_tags(MUBERT_CUSTOM_TAGS_STRING)
  prompts_embeddings = minilm.encode(prompts)
  ret = []
  for prompt, pe in zip(prompts, prompts_embeddings):
    scores, idxs = find_similar(pe, mubert_tags_embeddings)
    best = idxs[:top_n]
    top_tags = mubert_tags[best]
    # 1 - cosine distance = cosine similarity: a match score, not a probability.
    top_prob = 1 - scores[best]
    if debug:
      print(f"Prompt: {prompt}\nTags: {', '.join(top_tags)}\nScores: {top_prob}\n\n\n")
    ret.append((prompt, list(top_tags)))
  return ret

In [None]:
#@title ## 1.4 Setup Mubert channels

# List of category dicts from the GetPlayMusic API (categories -> groups ->
# channels). It is filled by load_all_mubert_channels below.
mubert_channels = []

def load_all_mubert_channels(pat):
  """Fetch every Mubert category (with its groups and channels) into ``mubert_channels``.

  Args:
    pat: Mubert PAT access token.

  Returns:
    The list of category dicts, which is also assigned to the global
    ``mubert_channels``.

  Raises:
    RuntimeError: if the response has no ``data`` payload.
  """
  global mubert_channels
  r = httpx.post(GET_CHANNELS_ENDPOINT, json={"method": GET_CHANNELS_METHOD, "params": {"pat": pat}})
  response = r.json()
  if 'data' not in response:
    error_text = (response.get('error') or {}).get('text', response)
    raise RuntimeError(f'GetPlayMusic failed: {error_text}')
  mubert_channels = response['data']['categories']
  return mubert_channels

# NOTE(review): this PAT is hard-coded in the notebook. It is the `pat` used by the
# later channel loading (3.1) and track generation (4.1, 4.3). The token obtained
# in section 2.1 is not used there; confirm this is intended.
pat = 'dHRtLjE3MzI5NTEzLjQ5NTFmNjQyOGU4MzE3MmE0ZjM5ZGUwNWQ1YjNhYjEwZDU4NTYwYjguMS4z.daa808932c238f6a58a0b82b972d82ba8c870232ada3e2a95024a1e364539ba4'
load_all_mubert_channels(pat)



In [None]:
#@title # 2. SETUP - USER INPUT NEEDED
#@markdown ## 2.1 Enter **either** your Mubert registered email **or** PAT access token
##@markdown This section receives a PAT access token then gets the channels   

def get_user_pat_token():
  """Return a Mubert PAT access token, based on the form fields below.

  If ``user_pat_token`` is filled in, it is returned as-is. Otherwise a token is
  requested from the GetServiceAccess API for ``mubert_registered_email``.

  Raises:
    ValueError: if both form fields are empty.
    RuntimeError: if the API does not report success (most likely an
      unregistered e-mail).
  """
  import httpx

  mubert_registered_email = "" #@param {type:"string"}
  user_pat_token = '' #@param{type:'string'}
  #@markdown <br>
  #@markdown
  #@markdown ## 👆 **Notes about these parameters** 👆
  #@markdown * <code><font size="5">mubert_registered_email</font></code><font size=4> - Email used to register on Mubert 
  #@markdown
  #@markdown * <code><font size="5">user_pat_token</font></code><font size="4"> - If you already have a pat token you can just enter it here. <br><font size=3>*NOTE: this is not secure, but anyone who knows the email address can also retrieve this token.*

  # Strip whitespace pasted along with the values.
  user_pat_token = user_pat_token.strip()
  if user_pat_token:
    print_html('This is the Mubert PAT token you entered 👇')
    print()
    print_html(user_pat_token, paragraph_type='h3')
    return user_pat_token

  email = mubert_registered_email.strip()
  if not email:
    raise ValueError('Enter either mubert_registered_email or user_pat_token in the form above.')
  r = httpx.post('https://api-b2b.mubert.com/v2/GetServiceAccess', 
      json={
          "method":"GetServiceAccess",
          "params": {
              "email": email,
              "license":"ttmmubertlicense#f0acYBenRcfeFpNT4wpYGaTQIyDI4mJGv5MfIhBFz97NXDwDNFHmMRsBSzmGsJwbTpP1A6i07AXcIeAHo5",
              "token":"4951f6428e83172a4f39de05d5b3ab10d58560b8",
              "mode": "loop"
              }})
  rdata = r.json()
  # Explicit check instead of `assert`, which is stripped under `python -O`.
  if rdata.get('status') != 1:
    raise RuntimeError("probably incorrect e-mail")
  pat = rdata['data']['pat']
  # <br> rather than '\n': newlines are collapsed when rendered as HTML.
  print_html(f'Got your personal Mubert PAT token! 👇 <br>{pat}', paragraph_type='h3')
  print_html('\n👆 You can save this token for next time 👆')
  return pat

# Now run it!
user_pat_token = get_user_pat_token()

In [None]:
#@title ## 2.2 Grab a link from the Mubert streams page
##@markdown ## *(you will need to copy the link by hand once the page has loaded !)* 

def get_pat_from_mubert_streamers_page():
  """Show streamers.mubert.com, ask the user to paste a stream link, and return its PAT.

  Links copied from that page look like
  ``https://stream.mubert.com/b2b/v2?playlist=...&intensity=...&pat=...``.
  The ``pat`` query parameter is extracted and returned.

  Raises:
    ValueError: if the pasted link has no ``pat`` query parameter.
  """
  from IPython.display import IFrame, clear_output
  from urllib.parse import urlparse, parse_qs

  display(IFrame('https://streamers.mubert.com/', width=600, height=440))
  copy_mubert_link_instructions = '👆 Wait until the Mubert web page has loaded in the frame above.. 👆 then.. <br><ol><li>press "Copy URL" next to the first playlist you see<br><li>copy the link with the copy button <br> <li>paste the link into the field below so we can get streamz!!'
  text_with_link_to_streams_page = "(btw if you want to check out the streams page here's the link: <a href='https://streamers.mubert.com/'>https://streamers.mubert.com</a>). But this Notebook can access many more streams!.."
  print_html(copy_mubert_link_instructions)
  print_html(text_with_link_to_streams_page, paragraph_type='h2')
  print_html('👇 Paste link you copied here and hit enter 👇')
  usr_input = input('\n')
  # Clear the iframe and instructions before showing the result.
  clear_output()
  mubert_stream_url = usr_input.strip()
  pat_values = parse_qs(urlparse(mubert_stream_url).query).get('pat')
  if not pat_values:
    raise ValueError(f'No "pat" parameter found in the pasted link: {mubert_stream_url!r}')
  pat = pat_values[0]
  print_html('PAT token extracted from the link')
  print_html(pat, paragraph_type='h3')
  return pat

# Now run it!
streaming_pat_token = get_pat_from_mubert_streamers_page()

In [None]:
#@title # 3. SETUP - MUBERT CHANNELS (for playing streams)
#@markdown ## 3.1 Load all Mubert channels (from Mubert API)

display_list_of_mubert_channels = True #@param{type:'boolean'} 
# Capture the checkbox value now: the function defined below has the same name
# and replaces this flag. Testing the name afterwards would always be true and
# ignore the checkbox.
show_mubert_channels_table = display_list_of_mubert_channels

#@markdown ----
#@markdown # **LFG!! - Pick a playlist ID from the table below and enter it in section 4.2 to play that stream (more info below).** 

#@markdown * #### In the table, the ID in square brackets next to each item is its Mubert playlist ID.
#@markdown * #### For example the 'Pumped' Channel can be played with ID: **0.1.0** <br>
#@markdown * #### **NOT ONLY Channels can be played. Groups and Categories can also be played. <br>When you use a Category or Group ID, then the AI uses ALL the underlying channels to compose the stream! So for example:**
#@markdown * #### the Category 'Moods' can be played with ID: **0**
#@markdown * #### the Group 'Energizing' can be played with ID: **0.1**
#@markdown <br>


#@markdown #### **After choosing a playlist you can remove the table (by clicking the little symbol next to the 'Category' column header)**

list_of_mubert_categories = []
list_of_mubert_groups = []
list_of_mubert_channels = []

def load_all_mubert_channels(pat):
  """Fetch and return every Mubert category, each with its groups and channels.

  Args:
    pat: Mubert PAT access token.

  Returns:
    The ``data.categories`` list from the GetPlayMusic response.

  Raises:
    RuntimeError: if the response has no ``data`` payload.
  """
  r = httpx.post(GET_CHANNELS_ENDPOINT, json={"method": GET_CHANNELS_METHOD, "params": {"pat": pat}})
  response = r.json()
  if 'data' not in response:
    error_text = (response.get('error') or {}).get('text', response)
    raise RuntimeError(f'GetPlayMusic failed: {error_text}')
  return response['data']['categories']

def display_list_of_mubert_channels():
  """Print a Category / Group / Channel table of ``mubert_channels`` with playlist IDs.

  ``mubert_channels`` is a list of category dicts. Each category holds a list of
  'groups', and each group holds a list of 'channels'. Every name printed is
  also appended to the matching ``list_of_mubert_*`` list.
  """
  table_format = '{:<25} {:<25} {:<25}'
  print(table_format.format('Category [playlist ID]', 'Group [playlist ID]', 'Channel [playlist ID]'))
  for category in mubert_channels:
    list_of_mubert_categories.append(category["name"])
    for group in category["groups"]:
      list_of_mubert_groups.append(group["name"])
      for channel in group["channels"]:
        list_of_mubert_channels.append(channel["name"])
        print(table_format.format(
            f'[{category["playlist"]}] {category["name"]}', 
            f'[{group["playlist"]}] {group["name"]}', 
            f'[{channel["playlist"]}] {channel["name"]}'
            )
        )

# Now run it!
mubert_channels = load_all_mubert_channels(pat)
if show_mubert_channels_table:
  display_list_of_mubert_channels()


# 4. PLAY STUFF!

In [None]:
#@title ## 4.1 🎵 Generate Music! (mp3 format) 🎵

skip_section = False #@param{type:'boolean'}
recommend_stream_from_prompt = True #@param{type:'boolean'}
# NOTE(review): recommend_stream_from_prompt is not read anywhere yet (see the
# "Use prompt to pick stream" TODO at the top of the notebook).

def generate_track_by_prompt(prompt, duration, loop=False):
  """Generate a Mubert track matching ``prompt`` and auto-play it.

  The prompt is mapped to its closest Mubert tags, and those are sent to
  ``get_track_by_tags`` together with the notebook-level ``pat`` token. Any
  error is printed instead of raised, so the cell still finishes.
  """
  [(_, prompt_tags)] = get_tags_for_prompts([prompt])
  try:
    get_track_by_tags(prompt_tags, pat, duration, autoplay=True, loop=loop)
  except Exception as err:
    print(str(err))
  print('\n')

if not skip_section:
  prompt = 'jump up drum and bass meets jazz in a shady bar in soho' #@param {type:"string"}
  duration = 30 #@param {type:"number"}
  loop = True #@param {type:"boolean"}
  generate_track_by_prompt(prompt, duration, loop=loop)

#@markdown ---
#@markdown ### 👆 **Notes about these parameters** 👆
#@markdown * `prompt` can have a maximum of 256 word pieces (tokens).
#@markdown \**This is a limitation of the all-MiniLM-L6-v2 'sentence transformers' model.
#@markdown Longer prompts are 'truncated' - i.e. tokens after the 256th are ignored.*
  


In [None]:
#@title ## 4.2 🎵 Play Stream (can't be downloaded) 🎵 
#@markdown #What's so cool about this?!
#@markdown * ## Mainly - from [streamers.mubert.com](https://streamers.mubert.com) only a few streams can be played, but with this, any Mubert Category, Group, or Channel that you can find in the Channels list (when you run code block 3.1 above) can be played!

#@markdown # **NOTES:**
#@markdown * ## **Stream is different every time this block runs**
#@markdown * ## **The stream cannot be downloaded**
#@markdown * ## **You could record it with something like [OBS Studio](https://obsproject.com/download)**
#@markdown ----
from html import escape as html_escape
from urllib.parse import urlencode
from IPython.display import IFrame
from IPython.display import Video

skip_section = False #@param{type:'boolean'}
playlist_id = "5.0.2" #@param{type:'string'}
intensity = "high" #@param['high','medium','low'] {type:'string'}

if not skip_section:
  # URL-encode the form values, so stray spaces or '&' in them cannot corrupt
  # the query string.
  stream_query = urlencode({'playlist': playlist_id.strip(), 'intensity': intensity, 'pat': streaming_pat_token})
  stream_url = f'https://stream.mubert.com/b2b/v2?{stream_query}'
  # Quote and escape the href: the URL contains '=' and '&', which are not
  # valid in an unquoted HTML attribute.
  print_html(f'<a href="{html_escape(stream_url)}">Here is the link to the stream!</a> (or just play below)', paragraph_type='h3')

  # A Video element plays the audio stream. display(IFrame(stream_url, ...))
  # also works, but it shows a larger frame.
  display(Video(data=stream_url, height=50, width=300))

In [None]:
#@title ## 4.3 Batch Music (mp3) generation 🎶
#@markdown ## Use multiple prompts to get multiple downloads!

skip_on_run_all = True #@param{type:'boolean'}

# Batch mode: each prompt below is mapped to Mubert tags and one track is
# generated per prompt (no autoplay). Per-prompt failures are printed rather
# than raised, so one bad prompt doesn't stop the rest.
if skip_on_run_all:
  print(f'batch section skipped because \'skip_on_run_all\' = {skip_on_run_all}')
else:
  duration = 60 #@param{type:'integer'}

  prompts = [
      'kind beaver guards life tree, stan lee, epic',
      'astronaut riding a horse',
      'winnie the pooh cooking methamphetamine',
      'vladimir lenin smoking weed with bob marley',
      'soviet retrofuturism',
      'two wasted friends high on weed are trying to navigate their way to their hostel in a big city, night, trippy',
      'an elephant levitating on a gas balloon',
      'calm music',
      'a refrigerator floating in a pond'
  ]

  tags = get_tags_for_prompts(prompts)

  for prompt_text, prompt_tags in tags:
    print(f'Prompt: {prompt_text}\nTags: {prompt_tags}')
    try:
      get_track_by_tags(prompt_tags, pat, duration, autoplay=False)
    except Exception as err:
      print(str(err))
    print('\n')